Browse Source

Merge pull request #21 from fastnlp/dev/classify

Dev/classify
tags/v0.1.0
Coet GitHub 7 years ago
parent
commit
2075693273
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 660 additions and 628 deletions
  1. +102
    -8
      fastNLP/core/action.py
  2. +111
    -42
      fastNLP/core/inference.py
  3. +47
    -196
      fastNLP/core/tester.py
  4. +94
    -328
      fastNLP/core/trainer.py
  5. +3
    -2
      fastNLP/models/cnn_text_classification.py
  6. +8
    -29
      fastNLP/models/sequence_modeling.py
  7. +1
    -0
      fastNLP/modules/decoder/CRF.py
  8. +4
    -4
      reproduction/chinese_word_seg/cws_train.py
  9. +87
    -1
      test/data_for_tests/people.txt
  10. +100
    -0
      test/data_for_tests/text_classify.txt
  11. +12
    -11
      test/seq_labeling.py
  12. +5
    -5
      test/test_cws.py
  13. +2
    -2
      test/test_tester.py
  14. +84
    -0
      test/text_classify.py

+ 102
- 8
fastNLP/core/action.py View File

@@ -1,16 +1,111 @@
"""
This file defines Action(s) and sample methods.

"""
from collections import Counter

import numpy as np
import torch


class Action(object):
"""
base class for Trainer and Tester
Operations shared by Trainer, Tester, and Inference.
This is designed for reducing replicate codes.
- make_batch: produce a min-batch of data. @staticmethod
- pad: padding method used in sequence modeling. @staticmethod
- mode: change network mode for either train or test. (for PyTorch) @staticmethod
The base Action shall define operations shared by as much task-specific Actions as possible.
"""

def __init__(self):
super(Action, self).__init__()

@staticmethod
def make_batch(iterator, data, use_cuda, output_length=True, max_len=None):
"""Batch and Pad data.
:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:param use_cuda: bool
:param output_length: whether to output the original length of the sequence before padding.
:param max_len: int, maximum sequence length
:return (batch_x, seq_len): tuple of two elements, if output_length is true.
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

return batch_x and batch_y, if output_length is False
"""
for indices in iterator:
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]

batch_x = Action.pad(batch_x)
# pad batch_y only if it is a 2-level list
if len(batch_y) > 0 and isinstance(batch_y[0], list):
batch_y = Action.pad(batch_y)

# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
batch_y = convert_to_torch_tensor(batch_y, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

if output_length:
seq_len = [len(x) for x in batch_x]
yield (batch_x, seq_len), batch_y
else:
yield batch_x, batch_y

@staticmethod
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length of this batch.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch

@staticmethod
def mode(model, test=False):
"""
Train mode or Test mode. This is for PyTorch currently.
:param model:
:param test:
"""
if test:
model.eval()
else:
model.train()


def convert_to_torch_tensor(data_list, use_cuda):
"""
convert lists into (cuda) Tensors
:param data_list: 2-level lists
:param use_cuda: bool
:param reqired_grad: bool
:return: PyTorch Tensor of shape [batch_size, max_seq_len]
"""
data_list = torch.Tensor(data_list).long()
if torch.cuda.is_available() and use_cuda:
data_list = data_list.cuda()
return data_list


def k_means_1d(x, k, max_iter=100):
"""
@@ -140,11 +235,10 @@ class Batchifier(object):

def __iter__(self):
batch = []
while True:
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if 0 < len(batch) < self.batch_size and self.drop_last is False:
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if 0 < len(batch) < self.batch_size and self.drop_last is False:
yield batch

+ 111
- 42
fastNLP/core/inference.py View File

@@ -1,7 +1,45 @@
import numpy as np
import torch

from fastNLP.core.action import Batchifier, SequentialSampler
from fastNLP.core.action import convert_to_torch_tensor
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils


def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None):
for indices in iterator:
batch_x = [data[idx] for idx in indices]
batch_x = pad(batch_x)
# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]
if min_len is not None and batch_x.size(1) < min_len:
pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
batch_x = torch.cat((batch_x, pad_tensor), 1)

if output_length:
seq_len = [len(x) for x in batch_x]
yield tuple([batch_x, seq_len])
else:
yield batch_x


def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch


class Inference(object):
@@ -9,7 +47,8 @@ class Inference(object):
This is an interface focusing on predicting output based on trained models.
It does not care about evaluations of the model, which is different from Tester.
This is a high-level model wrapper to be called by FastNLP.

This class does not share any operations with Trainer and Tester.
Currently, Inference does not support GPU.
"""

def __init__(self, pickle_path):
@@ -32,13 +71,11 @@ class Inference(object):

# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.batch_output.clear()

self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

num_iter = len(data) // self.batch_size
iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

for step in range(num_iter):
batch_x = self.make_batch(data)
for batch_x in self.make_batch(iterator, data, use_cuda=False):

prediction = self.data_forward(network, batch_x)

@@ -51,43 +88,12 @@ class Inference(object):
network.eval()
else:
network.train()
self.batch_output.clear()

def data_forward(self, network, x):
"""
This is only for sequence labeling with CRF decoder. TODO: more general ?
:param network:
:param x:
:return:
"""
seq_len = [len(seq) for seq in x]
x = torch.Tensor(x).long()
y = network(x)
prediction = network.prediction(y, seq_len)
# To do: hide framework
results = torch.Tensor(prediction).view(-1, )
return list(results.data)
raise NotImplementedError

def make_batch(self, data):
indices = next(self.iterator)
batch_x = [data[idx] for idx in indices]
if self.batch_size > 1:
batch_x = self.pad(batch_x)
return batch_x

@staticmethod
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + [fill * (max_length - len(sample))]
return batch
def make_batch(self, iterator, data, use_cuda):
raise NotImplementedError

def prepare_input(self, data):
"""
@@ -106,13 +112,76 @@ class Inference(object):
data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
return data_index

def prepare_output(self, data):
raise NotImplementedError


class SeqLabelInfer(Inference):
"""
Inference on sequence labeling models.
"""

def __init__(self, pickle_path):
super(SeqLabelInfer, self).__init__(pickle_path)

def data_forward(self, network, inputs):
"""
This is only for sequence labeling with CRF decoder.
:param network:
:param inputs:
:return: Tensor
"""
if not isinstance(inputs[1], list) and isinstance(inputs[0], list):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
y = network(x)
prediction = network.prediction(y, mask)
return torch.Tensor(prediction, required_grad=False)

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=True)

def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of list, of shape [num_batch, tag_seq_length]. Element type is Tensor.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
:return:
"""
results = []
for batch in batch_outputs:
results.append([self.index2label[int(x.data)] for x in batch])
for example in np.array(batch):
results.append([self.index2label[int(x)] for x in example])
return results


class ClassificationInfer(Inference):
"""
Inference on Classification models.
"""

def __init__(self, pickle_path):
super(ClassificationInfer, self).__init__(pickle_path)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=False, min_len=5)

def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
:return:
"""
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy())
results.append(self.index2label[idx])
return results

+ 47
- 196
fastNLP/core/tester.py View File

@@ -1,14 +1,14 @@
import _pickle
import os

import numpy as np
import torch

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.modules import utils


class BaseTester(Action):
class BaseTester(object):
"""docstring for Tester"""

def __init__(self, test_args):
@@ -37,25 +37,33 @@ class BaseTester(Action):
else:
self.model = network

# no backward setting for model
for param in network.parameters():
param.requires_grad = False

# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.eval_history.clear()
self.batch_output.clear()

dev_data = self.prepare_input(self.pickle_path)

self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))

num_iter = len(dev_data) // self.batch_size
iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
n_batches = len(dev_data) // self.batch_size
n_print = 1
step = 0

for step in range(num_iter):
batch_x, batch_y = self.make_batch(dev_data)
for batch_x, batch_y in self.make_batch(iterator, dev_data):

prediction = self.data_forward(network, batch_x)

eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)
step += 1

def prepare_input(self, data_path):
"""
@@ -64,51 +72,14 @@ class BaseTester(Action):
:return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
"""
if self.save_dev_data is None:
data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
self.save_dev_data = data_dev
return self.save_dev_data

def make_batch(self, data, output_length=True):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]
"""
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x_pad = self.pad(batch_x)
batch_y_pad = self.pad(batch_y)
if output_length:
seq_len = [len(x) for x in batch_x]
return (batch_x_pad, seq_len), batch_y_pad
else:
return batch_x_pad, batch_y_pad

@staticmethod
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch
def mode(self, model, test):
Action.mode(model, test)

def data_forward(self, network, data):
def data_forward(self, network, x):
raise NotImplementedError

def evaluate(self, predict, truth):
@@ -118,14 +89,6 @@ class BaseTester(Action):
def metrics(self):
raise NotImplementedError

def mode(self, model, test=True):
"""TODO: combine this function with Trainer ?? """
if test:
model.eval()
else:
model.train()
self.eval_history.clear()

def show_matrices(self):
"""
This is called by Trainer to print evaluation on dev set.
@@ -133,8 +96,11 @@ class BaseTester(Action):
"""
raise NotImplementedError

def make_batch(self, iterator, data):
raise NotImplementedError


class POSTester(BaseTester):
class SeqLabelTester(BaseTester):
"""
Tester for sequence labeling.
"""
@@ -143,44 +109,36 @@ class POSTester(BaseTester):
"""
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
super(POSTester, self).__init__(test_args)
super(SeqLabelTester, self).__init__(test_args)
self.max_len = None
self.mask = None
self.batch_result = None

def data_forward(self, network, inputs):
"""TODO: combine with Trainer

:param network: the PyTorch model
:param x: list of list, [batch_size, max_len]
:return y: [batch_size, num_classes]
"""
if not isinstance(inputs, tuple):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
if isinstance(inputs, tuple):
x = inputs[0]
self.seq_len = inputs[1]
else:
x = inputs
x = torch.Tensor(x).long()
x, seq_len = inputs[0], inputs[1]
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
if torch.cuda.is_available() and self.use_cuda:
x = x.cuda()
self.batch_size = x.size(0)
self.max_len = x.size(1)
mask = mask.cuda()
self.mask = mask

y = network(x)
return y

def evaluate(self, predict, truth):
truth = torch.Tensor(truth)
if torch.cuda.is_available() and self.use_cuda:
truth = truth.cuda()
loss = self.model.loss(predict, truth, self.seq_len) / self.batch_size
prediction = self.model.prediction(predict, self.seq_len)
batch_size, max_len = predict.size(0), predict.size(1)
loss = self.model.loss(predict, truth, self.mask) / batch_size

prediction = self.model.prediction(predict, self.mask)
results = torch.Tensor(prediction).view(-1,)
if torch.cuda.is_available() and self.use_cuda:
results = results.cuda()
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0]
return [loss.data, accuracy]
# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0]
return [loss.data, accuracy.data]

def metrics(self):
batch_loss = np.mean([x[0] for x in self.eval_history])
@@ -195,8 +153,11 @@ class POSTester(BaseTester):
loss, accuracy = self.metrics()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)


class ClassTester(BaseTester):
class ClassificationTester(BaseTester):
"""Tester for classification."""

def __init__(self, test_args):
@@ -204,7 +165,7 @@ class ClassTester(BaseTester):
:param test_args: a dict-like object that has __getitem__ method, \
can be accessed by "test_args["key_str"]"
"""
# super(ClassTester, self).__init__()
super(ClassificationTester, self).__init__(test_args)
self.pickle_path = test_args["pickle_path"]

self.save_dev_data = None
@@ -212,111 +173,8 @@ class ClassTester(BaseTester):
self.mean_loss = None
self.iterator = None

if "test_name" in test_args:
self.test_name = test_args["test_name"]
else:
self.test_name = "data_test.pkl"

if "validate_in_training" in test_args:
self.validate_in_training = test_args["validate_in_training"]
else:
self.validate_in_training = False

if "save_output" in test_args:
self.save_output = test_args["save_output"]
else:
self.save_output = False

if "save_loss" in test_args:
self.save_loss = test_args["save_loss"]
else:
self.save_loss = True

if "batch_size" in test_args:
self.batch_size = test_args["batch_size"]
else:
self.batch_size = 50
if "use_cuda" in test_args:
self.use_cuda = test_args["use_cuda"]
else:
self.use_cuda = True

if "max_len" in test_args:
self.max_len = test_args["max_len"]
else:
self.max_len = None

self.model = None
self.eval_history = []
self.batch_output = []

def test(self, network):
# prepare model
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network

# no backward setting for model
for param in self.model.parameters():
param.requires_grad = False

# turn on the testing mode; clean up the history
self.mode(network, test=True)

# prepare test data
data_test = self.prepare_input(self.pickle_path, self.test_name)

# data generator
self.iterator = iter(Batchifier(
RandomSampler(data_test), self.batch_size, drop_last=False))

# test
n_batches = len(data_test) // self.batch_size
n_print = n_batches // 10
step = 0
for batch_x, batch_y in self.make_batch(data_test, max_len=self.max_len):
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)

if step % n_print == 0:
print("step: {:>5}".format(step))

step += 1

def prepare_input(self, data_dir, file_name):
"""Prepare data."""
file_path = os.path.join(data_dir, file_name)
with open(file_path, 'rb') as f:
data = _pickle.load(f)
return data

def make_batch(self, data, max_len=None):
"""Batch and pad data."""
for indices in self.iterator:
# generate batch and pad
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

# convert to tensor
batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

yield batch_x, batch_y
def make_batch(self, iterator, data, max_len=None):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, max_len=max_len)

def data_forward(self, network, x):
"""Forward through network."""
@@ -337,10 +195,3 @@ class ClassTester(BaseTester):
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

def mode(self, model, test=True):
"""TODO: combine this function with Trainer ?? """
if test:
model.eval()
else:
model.train()
self.eval_history.clear()

+ 94
- 328
fastNLP/core/trainer.py View File

@@ -8,20 +8,18 @@ import torch
import torch.nn as nn

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler
from fastNLP.core.tester import POSTester
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.model_saver import ModelSaver


class BaseTrainer(Action):
class BaseTrainer(object):
"""Base trainer for all trainers.
Trainer receives a model and data, and then performs training.

Subclasses must implement the following abstract methods:
- prepare_input
- mode
- define_optimizer
- data_forward
- grad_backward
- get_loss
"""
@@ -75,25 +73,29 @@ class BaseTrainer(Action):
data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path)

# define tester over dev data
# TODO: more flexible
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda}
validator = POSTester(valid_args)
if self.validate:
default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda}
validator = self._create_validator(default_valid_args)

# main training epochs
iterations = len(data_train) // self.batch_size
self.define_optimizer()

# main training epochs
start = time()
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_print = 1

for epoch in range(1, self.n_epochs + 1):

# turn on network training mode; define optimizer; prepare batch iterator
self.mode(test=False)
self.iterator = iter(Batchifier(BucketSampler(data_train), self.batch_size, drop_last=True))
# turn on network training mode; prepare batch iterator
self.mode(network, test=False)
iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))

# training iterations in one epoch
for step in range(iterations):
batch_x, batch_y = self.make_batch(data_train)
step = 0
for batch_x, batch_y in self.make_batch(iterator, data_train):

prediction = self.data_forward(network, batch_x)

@@ -101,12 +103,14 @@ class BaseTrainer(Action):
self.grad_backward(loss)
self.update()

if step % 10 == 0:
print("[epoch {} step {}] train loss={:.2f}".format(epoch, step, loss.data))
if step % n_print == 0:
end = time()
diff = timedelta(seconds=round(end - start))
print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
epoch, step, loss.data, diff))
step += 1

if self.validate:
if data_dev is None:
raise RuntimeError("No validation data provided.")
validator.test(network)

if self.save_best_dev and self.best_eval_result(validator):
@@ -116,22 +120,32 @@ class BaseTrainer(Action):
print("[epoch {}]".format(epoch), end=" ")
print(validator.show_matrices())

# finish training

def prepare_input(self, data_path):
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb"))
data_test = _pickle.load(open(data_path + "data_test.pkl", "rb"))
embedding = _pickle.load(open(data_path + "embedding.pkl", "rb"))
return data_train, data_dev, data_test, embedding

def mode(self, test=False):
def prepare_input(self, pickle_path):
"""
Tell the network to be trained or not.
:param test: bool
For task-specific processing.
:param pickle_path:
:return data_train, data_dev, data_test, embedding:
"""
names = [
"data_train.pkl", "data_dev.pkl",
"data_test.pkl", "embedding.pkl"]
files = []
for name in names:
file_path = os.path.join(pickle_path, name)
if os.path.exists(file_path):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
data = []
files.append(data)
return tuple(files)

def make_batch(self, iterator, data):
raise NotImplementedError

def mode(self, network, test):
Action.mode(network, test)

def define_optimizer(self):
"""
Define framework-specific optimizer specified by the models.
@@ -147,14 +161,6 @@ class BaseTrainer(Action):
raise NotImplementedError

def data_forward(self, network, x):
"""
Forward pass of the data.
:param network: a model
:param x: input feature matrix and label vector
:return: output by the models

For PyTorch, just do "network(*x)"
"""
raise NotImplementedError

def grad_backward(self, loss):
@@ -187,50 +193,6 @@ class BaseTrainer(Action):
"""
raise NotImplementedError

def make_batch(self, data, output_length=True):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:return (batch_x, seq_len): tuple of two elements, if output_length is true.
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

return batch_x and batch_y, if output_length is False
"""
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x_pad = self.pad(batch_x)
batch_y_pad = self.pad(batch_y)
if output_length:
seq_len = [len(x) for x in batch_x]
return (batch_x_pad, seq_len), batch_y_pad
else:
return batch_x_pad, batch_y_pad

@staticmethod
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch

def best_eval_result(self, validator):
"""
:param validator: a Tester instance
@@ -245,6 +207,9 @@ class BaseTrainer(Action):
"""
ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network)

def _create_validator(self, valid_args):
raise NotImplementedError


class ToyTrainer(BaseTrainer):
"""
@@ -259,12 +224,6 @@ class ToyTrainer(BaseTrainer):
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_train, data_dev, 0, 1

def mode(self, test=False):
if test:
self.model.eval()
else:
self.model.train()

def data_forward(self, network, x):
return network(x)

@@ -282,53 +241,20 @@ class ToyTrainer(BaseTrainer):
self.optimizer.step()


class POSTrainer(BaseTrainer):
class SeqLabelTrainer(BaseTrainer):
"""
Trainer for Sequence Modeling

"""

def __init__(self, train_args):
super(POSTrainer, self).__init__(train_args)
super(SeqLabelTrainer, self).__init__(train_args)
self.vocab_size = train_args["vocab_size"]
self.num_classes = train_args["num_classes"]
self.max_len = None
self.mask = None
self.best_accuracy = 0.0

def prepare_input(self, data_path):

data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_train, data_dev, 0, 1

def data_forward(self, network, inputs):
"""
:param network: the PyTorch model
:param inputs: list of list, [batch_size, max_len],
or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len]
:return y: [batch_size, max_len, tag_size]
"""
# unpack the returned value from make_batch
if isinstance(inputs, tuple):
x = inputs[0]
self.seq_len = inputs[1]
else:
x = inputs
x = torch.Tensor(x).long()
if torch.cuda.is_available() and self.use_cuda:
x = x.cuda()
self.batch_size = x.size(0)
self.max_len = x.size(1)

y = network(x)
return y

def mode(self, test=False):
if test:
self.model.eval()
else:
self.model.train()

def define_optimizer(self):
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9)

@@ -339,6 +265,23 @@ class POSTrainer(BaseTrainer):
def update(self):
self.optimizer.step()

def data_forward(self, network, inputs):
if not isinstance(inputs, tuple):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]

batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)

if torch.cuda.is_available() and self.use_cuda:
mask = mask.cuda()
self.mask = mask

y = network(x)
return y

def get_loss(self, predict, truth):
"""
Compute loss given prediction and ground truth.
@@ -346,17 +289,10 @@ class POSTrainer(BaseTrainer):
:param truth: ground truth label vector, [batch_size, max_len]
:return: a scalar
"""
truth = torch.Tensor(truth)
if torch.cuda.is_available() and self.use_cuda:
truth = truth.cuda()
assert truth.shape == (self.batch_size, self.max_len)
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
else:
self.define_loss()
loss = self.loss_func(predict, truth, self.seq_len)
# print("loss={:.2f}".format(loss.data))
batch_size, max_len = predict.size(0), predict.size(1)
assert truth.shape == (batch_size, max_len)

loss = self.model.loss(predict, truth, self.mask)
return loss

def best_eval_result(self, validator):
@@ -367,62 +303,18 @@ class POSTrainer(BaseTrainer):
else:
return False

def make_batch(self, data, output_length=True):
"""
1. Perform batching from data and produce a batch of training data.
2. Add padding.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
E.g.
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:return (batch_x, seq_len): tuple of two elements, if output_length is true.
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels]

return batch_x and batch_y, if output_length is False
"""
indices = next(self.iterator)
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x_pad = self.pad(batch_x)
batch_y_pad = self.pad(batch_y)
if output_length:
seq_len = [len(x) for x in batch_x]
return (batch_x_pad, seq_len), batch_y_pad
else:
return batch_x_pad, batch_y_pad

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda)

class LanguageModelTrainer(BaseTrainer):
"""
Trainer for Language Model
"""
def _create_validator(self, valid_args):
return SeqLabelTester(valid_args)

def __init__(self, train_args):
super(LanguageModelTrainer, self).__init__(train_args)

def prepare_input(self, data_path):
pass


class ClassTrainer(BaseTrainer):
class ClassificationTrainer(BaseTrainer):
"""Trainer for classification."""

def __init__(self, train_args):
# super(ClassTrainer, self).__init__(train_args)
self.n_epochs = train_args["epochs"]
self.batch_size = train_args["batch_size"]
self.pickle_path = train_args["pickle_path"]

if "validate" in train_args:
self.validate = train_args["validate"]
else:
self.validate = False
super(ClassificationTrainer, self).__init__(train_args)
if "learn_rate" in train_args:
self.learn_rate = train_args["learn_rate"]
else:
@@ -431,127 +323,14 @@ class ClassTrainer(BaseTrainer):
self.momentum = train_args["momentum"]
else:
self.momentum = 0.9
if "use_cuda" in train_args:
self.use_cuda = train_args["use_cuda"]
else:
self.use_cuda = True

self.model = None
self.iterator = None
self.loss_func = None
self.optimizer = None

def train(self, network):
"""General Training Steps
:param network: a model

The method is framework independent.
Work by calling the following methods:
- prepare_input
- mode
- define_optimizer
- data_forward
- get_loss
- grad_backward
- update
Subclasses must implement these methods with a specific framework.
"""
# prepare model and data, transfer model to gpu if available
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network
data_train, data_dev, data_test, embedding = self.prepare_input(
self.pickle_path)

# define tester over dev data
# valid_args = {
# "save_output": True, "validate_in_training": True,
# "save_dev_input": True, "save_loss": True,
# "batch_size": self.batch_size, "pickle_path": self.pickle_path}
# validator = POSTester(valid_args)

# urn on network training mode, define loss and optimizer
self.define_loss()
self.define_optimizer()
self.mode(test=False)

# main training epochs
start = time()
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_print = n_batches // 10
for epoch in range(self.n_epochs):
# prepare batch iterator
self.iterator = iter(Batchifier(
RandomSampler(data_train), self.batch_size, drop_last=False))

# training iterations in one epoch
step = 0
for batch_x, batch_y in self.make_batch(data_train):
prediction = self.data_forward(network, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()

if step % n_print == 0:
acc = self.get_acc(prediction, batch_y)
end = time()
diff = timedelta(seconds=round(end - start))
print("epoch: {:>3} step: {:>4} loss: {:>4.2}"
" train acc: {:>5.1%} time: {}".format(
epoch, step, loss, acc, diff))

step += 1

# if self.validate:
# if data_dev is None:
# raise RuntimeError("No validation data provided.")
# validator.test(network)
# print("[epoch {}]".format(epoch), end=" ")
# print(validator.show_matrices())

# finish training

def prepare_input(self, data_path):

names = [
"data_train.pkl", "data_dev.pkl",
"data_test.pkl", "embedding.pkl"]

files = []
for name in names:
file_path = os.path.join(data_path, name)
if os.path.exists(file_path):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
data = []
files.append(data)

return tuple(files)

def mode(self, test=False):
"""
Tell the network to be trained or not.
:param test: bool
"""
if test:
self.model.eval()
else:
self.model.train()
self.best_accuracy = 0

def define_loss(self):
"""
Assign an instance of loss function to self.loss_func
E.g. self.loss_func = nn.CrossEntropyLoss()
"""
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
else:
self.loss_func = nn.CrossEntropyLoss()
self.loss_func = nn.CrossEntropyLoss()

def define_optimizer(self):
"""
@@ -567,10 +346,6 @@ class ClassTrainer(BaseTrainer):
logits = network(x)
return logits

def get_loss(self, predict, truth):
"""Calculate loss."""
return self.loss_func(predict, truth)

def grad_backward(self, loss):
"""Compute gradient backward."""
self.model.zero_grad()
@@ -580,30 +355,21 @@ class ClassTrainer(BaseTrainer):
"""Apply gradient."""
self.optimizer.step()

def make_batch(self, data):
"""Batch and pad data."""
for indices in self.iterator:
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

yield batch_x, batch_y
def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda)

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)

def best_eval_result(self, validator):
_, _, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False

if __name__ == "__name__":
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
trainer = BaseTrainer(train_args)
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
trainer.make_batch(data=data_train)
def _create_validator(self, valid_args):
return ClassificationTester(valid_args)

+ 3
- 2
fastNLP/models/cnn_text_classification.py View File

@@ -1,13 +1,14 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn

# import torch.nn.functional as F
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool


class CNNText(BaseModel):
class CNNText(torch.nn.Module):
"""
Text classification model by character CNN, the implementation of paper
'Yoon Kim. 2014. Convolution Neural Networks for Sentence


+ 8
- 29
fastNLP/models/sequence_modeling.py View File

@@ -1,7 +1,7 @@
import torch

from fastNLP.models.base_model import BaseModel
from fastNLP.modules import decoder, encoder, utils
from fastNLP.modules import decoder, encoder


class SeqLabeling(BaseModel):
@@ -34,46 +34,25 @@ class SeqLabeling(BaseModel):
# [batch_size, max_len, num_classes]
return x

def loss(self, x, y, seq_length):
def loss(self, x, y, mask):
"""
Negative log likelihood loss.
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param y: LongTensor, [batch_size, max_len]
:param seq_length: list of int. [batch_size]
:param x: Tensor, [batch_size, max_len, tag_size]
:param y: Tensor, [batch_size, max_len]
:param mask: ByteTensor, [batch_size, ,max_len]
:return loss: a scalar Tensor

"""
x = x.float()
y = y.long()

batch_size = x.size(0)
max_len = x.size(1)

mask = utils.seq_mask(seq_length, max_len)
mask = mask.byte().view(batch_size, max_len)

# TODO: remove
if torch.cuda.is_available():
mask = mask.cuda()
# mask = x.new(batch_size, max_len)

total_loss = self.Crf(x, y, mask)

return torch.mean(total_loss)

def prediction(self, x, seq_length):
def prediction(self, x, mask):
"""
:param x: FloatTensor, [batch_size, max_len, tag_size]
:param seq_length: int
:return prediction: list of tuple of (decode path(list), best score)
:param mask: ByteTensor, [batch_size, max_len]
:return prediction: list of [decode path(list)]
"""
x = x.float()
max_len = x.size(1)

mask = utils.seq_mask(seq_length, max_len)
# hack: make sure mask has the same device as x
mask = mask.to(x).byte()

tag_seq = self.Crf.viterbi_decode(x, mask)

return tag_seq

+ 1
- 0
fastNLP/modules/decoder/CRF.py View File

@@ -132,6 +132,7 @@ class ConditionalRandomField(nn.Module):
Given a feats matrix, return best decode path and best score.
:param feats:
:param masks:
:param get_score: bool, whether to output the decode score.
:return:List[Tuple(List, float)],
"""
batch_size, max_len, tag_size = feats.size()


+ 4
- 4
reproduction/chinese_word_seg/cws_train.py View File

@@ -3,12 +3,12 @@ import sys
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference

@@ -64,7 +64,7 @@ def train():
train_args["num_classes"] = p.num_classes

# Trainer
trainer = POSTrainer(train_args)
trainer = SeqLabelTrainer(train_args)

# Model
model = SeqLabeling(train_args)
@@ -96,7 +96,7 @@ def test():
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

# Tester
tester = POSTester(test_args)
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)


+ 87
- 1
test/data_for_tests/people.txt View File

@@ -64,4 +64,90 @@
3 B-t
1 M-t
日 E-t
, S-w
, S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

+ 100
- 0
test/data_for_tests/text_classify.txt View File

@@ -0,0 +1,100 @@
entertainment 台 媒 预 测 周 冬 雨 金 马 奖 封 后 , 大 气 的 倪 妮 却 佳 作 难 出
food 农 村 就 是 好 , 能 吃 到 纯 天 然 无 添 加 的 野 生 蜂 蜜 , 营 养 又 健 康
fashion 1 4 款 知 性 美 装 , 时 尚 惊 艳 搁 浅 的 阳 光 轻 熟 的 优 雅
history 火 焰 喷 射 器 1 0 0 0 度 火 焰 烧 死 鬼 子 4 连 拍
society 1 8 岁 青 年 砍 死 8 8 岁 老 兵
fashion 醋 洗 脸 的 正 确 方 法 洗 对 了 不 仅 美 容 肌 肤 还 能 收 缩 毛 孔
game 大 家 都 说 说 除 了 这 1 0 个 英 雄 , L O L 还 有 哪 些 英 雄 可 以 单 挑 男 爵
sports 王 仕 鹏 退 役 担 任 N B A 总 决 赛 现 场 解 说 嘉 宾
regimen 天 天 吃 “ 洋 快 餐 ” , 5 岁 女 童 患 上 肝 炎
food 汤 里 的 蛋 花 怎 样 才 能 如 花 朵 般 漂 亮 , 注 意 这 一 点 即 可 !
tech 英 退 休 人 士 把 谷 歌 当 活 人 以 礼 貌 搜 索 请 求 征 服 整 个 互 联 网
discovery N A S A 探 测 器 拍 摄 地 球 、 火 星 和 冥 王 星 合 影
society 当 骗 子 遇 上 撒 贝 宁 ! 几 句 话 过 后 骗 子 赔 礼 道 歉 . . . . .
history 红 军 长 征 在 中 国 革 命 史 上 的 地 位
world 实 拍 神 秘 之 国 , 带 你 走 进 真 实 的 朝 鲜
tech 逼 格 爆 表 ! 古 文 版 2 0 1 6 网 络 热 词 : 燃 尽 洪 荒 之 力
story 因 为 一 样 东 西 这 个 后 娘 竟 然 给 孩 子 磕 头
game L O L : 皮 肤 对 操 作 没 影 响 ? 细 数 那 些 有 加 成 效 果 的 皮 肤
fashion 冬 天 想 穿 裙 子 又 怕 冷 ? 学 了 这 些 搭 配 就 能 好 看 又 温 暖 !
entertainment 贾 建 军 少 林 三 光 剑 视 频
food 再 也 不 用 出 去 吃 羊 肉 串 , 自 己 做 又 卫 生 又 健 康
regimen 男 人 多 吃 这 几 道 菜 , 效 果 胜 “ 伟 哥 ”
baby 宝 贝 厨 房 丨 肉 类 辅 食 第 一 步 宝 宝 的 生 长 发 育 每 天 都 离 不 开 它 !
travel 近 8 0 亿 的 顶 级 豪 华 邮 轮 上 到 底 有 什 么 ?
sports 厄 齐 尔 心 中 最 想 签 约 的 三 个 人
food 东 北 的 粘 豆 包 啊 , 想 死 你 们 了 !
military 强 军 足 音
sports 奥 运 赛 场 上 , 被 喷 子 痛 批 的 十 大 知 名 运 动 员
game 老 玩 家 分 享 对 2 0 1 6 L P L 夏 季 赛 R N G 的 分 析
military 揭 秘 : 关 于 战 争 的 五 大 真 相 , 不 要 再 被 影 视 所 欺 骗 了 !
food 小 丫 厨 房 : 夏 天 怎 么 吃 辣 不 长 痘 ? 告 诉 你 火 锅 鸡 、 香 辣 鱼 的 正 确 做 法
travel 中 国 首 个 内 陆 城 市 群 上 的 9 座 城 市 , 看 看 有 你 的 家 乡 吗
fashion 李 小 璐 做 榜 样 接 亲 吻 脚 大 流 行 新 娘 玉 足 怎 样 才 有 好 味 道 ?
game 黄 金 吊 打 钻 石 ? L O L 最 强 刷 钱 毒 瘤 打 法 诞 生
history 奇 事 ! 上 万 只 青 蛙 拦 路 告 状 , 竟 然 牵 扯 出 一 桩 命 案
baby 奶 奶 , 你 为 什 么 不 让 我 用 尿 不 湿
game L O L 当 5 个 大 发 明 家 炮 台 围 住 泉 水 的 时 候 : 这 是 真 虐 泉 !
essay 文 友 忠 告 暖 人 心 : 人 到 中 年 “ 不 交 五 友 ”
travel 这 一 年 , 我 们 去 日 本
food 好 吃 早 饭 近 似 吃 补 药
fashion 夏 天 太 热 , 唇 膏 化 了 如 何 办 ?
society 厂 里 面 的 9 0 后 打 工 妹 , 辛 苦 来 之 不 易
history 罕 见 老 照 片 展 示 美 国 大 萧 条 时 期 景 象
world 美 国 总 统 奥 巴 马 , 是 童 心 未 泯 的 温 情 奥 大 大 , 还 是 个 超 级 老 顽 童
finance 脱 欧 公 投 前 一 天 抛 售 英 镑 这 一 次 索 罗 斯 也 被 “ 打 败 ” 了 . . .
history 翻 越 长 征 路 上 第 一 座 大 山
world 朝 鲜 批 奥 巴 马 涉 朝 言 论 , 称 只 要 核 威 胁 存 在 将 继 续 强 化 核 武 力 量
game 《 巫 师 3 : 狂 猎 》 不 良 因 素 解 析 攻 略
travel 在 郑 州 有 个 地 方 , 时 光 仿 佛 在 那 儿 停 下 脚 步
history 它 号 称 “ 天 下 第 一 团 ” , 走 出 过 1 4 位 共 和 国 将 军 以 及 一 位 著 名 作 家
car 煤 老 板 去 黄 江 买 车 , 以 为 占 了 便 宜 没 想 被 坑 了 1 0 0 多 万
society “ 试 管 婴 儿 之 母 ” 张 丽 珠 遗 体 告 别 仪 式 8 日 举 行
sports 东 京 奥 运 会 , 中 国 女 排 卫 冕 的 几 率 有 多 大 ?
travel 成 都 我 们 永 远 依 恋 的 城 市
tech 雷 布 斯 除 了 小 米 还 有 这 些 秘 密 , 你 知 道 吗 ?
world “ 仲 裁 庭 损 害 国 际 法 体 系 公 正 性 ” — — 访 武 汉 大 学 中 国 边 界 与 海 洋 研 究 院 首 席 专 家 易 显 河
entertainment 上 海 观 众 和 欧 洲 三 大 影 展 之 间 的 距 离 : 零 时 差
essay 关 系 好 , 一 切 便 好
baby 刚 出 生 不 到 1 小 时 的 白 鲸 宝 宝 被 冲 上 岸 , 被 救 后 对 恩 人 露 出 微 笑
tech 赚 足 眼 球 , 诺 基 亚 五 边 形 W i n 1 0 M o b i l e 概 念 手 机 : 棱 镜
essay 2 4 句 经 典 语 录 : 穷 三 年 可 以 怨 命 , 穷 十 年 就 得 自 省
food 这 道 菜 真 下 饭 ! 做 法 简 单 , 防 辐 射 、 抗 衰 老 , 关 键 还 便 宜
entertainment 《 继 承 者 们 》 要 拍 中 国 版 , 众 角 色 你 期 待 谁 来 演 ?
game D N F 暴 走 改 版 后 怎 么 样 D N F 暴 走 改 版 红 眼 变 弱 了 吗
entertainment 郑 佩 佩 自 曝 与 李 小 龙 的 过 去 他 是 个 “ 疯 子 ”
baby 女 性 只 有 8 4 次 最 佳 受 孕 机 会
travel 月 初 一 个 人 去 了 日 本 . .
military 不 为 人 知 的 8 0 万 苏 联 女 兵 ! 最 后 一 张 很 美 !
tech 网 络 商 家 提 供 小 米 5 运 存 升 级 服 务 : 3 G B 秒 变 6 G B
history 宋 太 祖 、 宋 太 宗 凌 辱 亡 国 皇 后 , 徽 钦 二 帝 后 宫 被 金 人 凌 辱
history 人 有 三 面 最 “ 难 吃 ” ! 黑 帮 大 佬 杜 月 笙 论 江 湖 规 矩 ! 一 生 只 怕 这 一 人
game 来 了 ! 索 尼 P S 4 独 占 大 作 《 战 神 4 》 正 式 公 布
discovery 延 时 视 频 显 示 珊 瑚 如 何 “ 驱 逐 ” 共 生 藻 类
car 传 祺 G A 8 和 东 风 A 9 谁 才 是 自 主 “ 豪 车 ” 大 佬
fashion 娶 老 婆 就 要 娶 这 种 ! 蒋 欣 这 样 微 胖 的 女 人 好 看 又 实 用
sports 黄 山 姑 娘 吕 秀 芝 勇 夺 奥 运 铜 牌 数 百 父 老 彻 夜 为 她 加 油
military [ 每 日 军 图 ] 土 豪 补 仓 ! 沙 特 再 次 购 买 上 百 辆 美 国 M 1 A 2 主 战 坦 克
military 美 军 这 款 武 器 号 称 能 让 半 个 中 国 陷 入 黑 暗 , 解 放 军 少 将 : 我 们 也 有
world 邓 小 平 与 日 本 天 皇 的 历 史 性 会 谈 , 对 中 日 两 国 都 具 有 深 远 的 意 义 啊 !
baby 为 什 么 有 人 上 个 厕 所 都 能 生 出 孩 子 ?
fashion 欣 宜 举 行 首 次 个 唱 十 万 颗 宝 仕 奥 莎 仿 水 晶 闪 耀 全 场
food 小 两 口 上 周 的 晚 餐
society 在 北 京 就 要 守 规 矩
entertainment 知 情 人 曝 翰 爽 分 手 内 幕 : 郑 爽 想 结 婚 却 被 一 直 拖 着
military 中 国 反 舰 导 弹 世 界 第 一 远 远 超 过 美 国 但 为 何 却 还 不 如 俄 罗 斯 ?
entertainment 他 除 了 是 《 我 歌 》 音 乐 总 监 , 还 曾 组 乐 队 玩 摇 滚 , 是 黄 家 驹 旧 日 知 己
baby 长 鹅 口 疮 的 孩 子 怎 么 照 顾 ? 不 要 再 说 拿 他 没 办 法 了 !
discovery 微 重 力 不 需 使 用 肌 肉 , 太 空 人 返 回 地 球 后 脊 椎 旁 肌 肉 萎 缩 约 1 9 %
regimen 这 6 种 人 将 来 会 得 老 年 痴 呆 ! 预 防 老 年 痴 呆 症 , 这 些 办 法 被 全 世 界 公 认
society 2 0 1 6 年 上 海 即 将 发 生 哪 些 大 事 件 。 。 。 。
car 北 汽 自 主 品 牌 亏 损 3 3 . 4 1 亿 额 外 促 销 成 主 因
car 在 那 山 的 那 边 海 的 那 边 , 有 一 群 自 由 侠
history 一 个 小 城 就 屠 杀 了 4 0 0 0 苏 军 战 俘 , 希 特 勒 死 神 战 队 的 崛 起 与 覆 灭
baby 给 孩 子 洗 澡 时 , 这 些 部 位 再 脏 也 不 要 碰 !
essay 好 久 不 见 , 你 还 好 么
baby 被 娃 误 伤 的 9 种 痛 , 数 一 数 你 中 了 几 枪 ?
food 初 秋 的 小 炖 品 放 冰 糖 就 比 较 滋 润 , 放 红 糖 就 补 血 又 不 燥 热
game 佩 服 佩 服 ! 羊 驼 D e f t 单 排 重 回 韩 服 最 强 王 者 第 一 名 !
game 三 个 时 代 的 标 志 炉 石 传 说 三 大 远 古 毒 瘤 卡 组
discovery 2 0 世 纪 最 伟 大 科 学 发 现 — — 魔 术 般 的 超 导 材 料 !

test/test_seq_labeling.py → test/seq_labeling.py View File

@@ -3,14 +3,14 @@ import sys
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference
from fastNLP.core.inference import SeqLabelInfer

data_name = "people.txt"
data_path = "data_for_tests/people.txt"
@@ -50,14 +50,15 @@ def infer():
"""

# Inference interface
infer = Inference(pickle_path)
infer = SeqLabelInfer(pickle_path)
results = infer.predict(model, infer_data)

print(results)
for res in results:
print(res)
print("Inference finished!")


def train_test():
def train_and_test():
# Config Loader
train_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
@@ -67,12 +68,12 @@ def train_test():
train_data = pos_loader.load_lines()

# Preprocessor
p = POSPreprocess(train_data, pickle_path)
p = POSPreprocess(train_data, pickle_path, train_dev_split=0.5)
train_args["vocab_size"] = p.vocab_size
train_args["num_classes"] = p.num_classes

# Trainer
trainer = POSTrainer(train_args)
trainer = SeqLabelTrainer(train_args)

# Model
model = SeqLabeling(train_args)
@@ -100,7 +101,7 @@ def train_test():
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

# Tester
tester = POSTester(test_args)
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)
@@ -111,5 +112,5 @@ def train_test():


if __name__ == "__main__":
train_test()
# infer()
# train_and_test()
infer()

+ 5
- 5
test/test_cws.py View File

@@ -3,12 +3,12 @@ import sys
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import POSTrainer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference

@@ -73,7 +73,7 @@ def train_test():
train_args["num_classes"] = p.num_classes

# Trainer
trainer = POSTrainer(train_args)
trainer = SeqLabelTrainer(train_args)

# Model
model = SeqLabeling(train_args)
@@ -101,7 +101,7 @@ def train_test():
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

# Tester
tester = POSTester(test_args)
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)
@@ -113,4 +113,4 @@ def train_test():

if __name__ == "__main__":
train_test()
#infer()
infer()

+ 2
- 2
test/test_tester.py View File

@@ -1,4 +1,4 @@
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
@@ -26,7 +26,7 @@ def foo():
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
"use_cuda": True}
validator = POSTester(valid_args)
validator = SeqLabelTester(valid_args)

print("start validation.")
validator.test(model)


+ 84
- 0
test/text_classify.py View File

@@ -0,0 +1,84 @@
# Python: 3.5
# encoding: utf-8

import os

from fastNLP.core.inference import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver

data_dir = "./data_for_tests/"
train_file = 'text_classify.txt'
model_name = "model_class.pkl"


def infer():
# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
data = ds_loader.load()
unlabeled_data = [x[0] for x in data]

# pre-process data
pre = ClassPreprocess(data_dir)
vocab_size, n_classes = pre.process(data, "data_train.pkl")
print("vocabulary size:", vocab_size)
print("number of classes:", n_classes)

# construct model
print("Building model...")
cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
# Dump trained parameters into the model
ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
print("model loaded!")

infer = ClassificationInfer(data_dir)
results = infer.predict(cnn, unlabeled_data)
print(results)


def train():
# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
data = ds_loader.load()
print(data[0])

# pre-process data
pre = ClassPreprocess(data_dir)
vocab_size, n_classes = pre.process(data, "data_train.pkl")
print("vocabulary size:", vocab_size)
print("number of classes:", n_classes)

# construct model
print("Building model...")
cnn = CNNText(class_num=n_classes, embed_num=vocab_size)

# train
print("Training...")
train_args = {
"epochs": 1,
"batch_size": 10,
"pickle_path": data_dir,
"validate": False,
"save_best_dev": False,
"model_saved_path": "./data_for_tests/",
"use_cuda": True
}
trainer = ClassificationTrainer(train_args)
trainer.train(cnn)

print("Training finished!")

saver = ModelSaver("./data_for_tests/saved_model.pkl")
saver.save_pytorch(cnn)
print("Model saved!")


if __name__ == "__main__":
# train()
infer()

Loading…
Cancel
Save