
changes to Trainer, Tester & Inference:

- rename "POSTrainer" & "POSTester" to "SeqLabelTrainer" & "SeqLabelTester"
- Trainer & Tester no longer depend on Action
- Inference owns its own independent "make_batch" & "data_forward"
- conversion to Tensor & moving data onto CUDA now happen inside "make_batch"
- "make_batch" supports maximum/minimum sequence length
tags/v0.1.0
FengZiYjun committed 6 years ago
parent commit 8e6db05339

10 changed files with 243 additions and 264 deletions
  1. fastNLP/core/action.py (+34 -5)
  2. fastNLP/core/inference.py (+87 -42)
  3. fastNLP/core/tester.py (+28 -129)
  4. fastNLP/core/trainer.py (+36 -71)
  5. fastNLP/models/cnn_text_classification.py (+3 -2)
  6. reproduction/chinese_word_seg/cws_train.py (+2 -2)
  7. test/seq_labeling.py (+3 -3)
  8. test/test_cws.py (+2 -2)
  9. test/test_tester.py (+2 -2)
  10. test/text_classify.py (+46 -6)

fastNLP/core/action.py (+34 -5)

@@ -5,6 +5,7 @@
from collections import Counter

import numpy as np
import torch


class Action(object):
@@ -21,7 +22,7 @@ class Action(object):
super(Action, self).__init__()

@staticmethod
def make_batch(iterator, data, output_length=True):
def make_batch(iterator, data, use_cuda, output_length=True, max_len=None):
"""Batch and Pad data.
:param iterator: an iterator, (object that implements __next__ method) which returns the next sample.
:param data: list. Each entry is a sample, which is also a list of features and label(s).
@@ -31,7 +32,9 @@ class Action(object):
[[word_21, word_22, word_23], [label_21, label_22]], # sample 2
...
]
:param use_cuda: bool
:param output_length: whether to output the original length of the sequence before padding.
:param max_len: int, maximum sequence length
:return (batch_x, seq_len): tuple of two elements, if output_length is true.
batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len]
seq_len: list. The length of the pre-padded sequence, if output_length is True.
@@ -43,13 +46,25 @@ class Action(object):
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x_pad = Action.pad(batch_x)
batch_y_pad = Action.pad(batch_y)

# record original sequence lengths before padding;
# len() on a padded tensor row would always equal the padded width
seq_len = [len(x) for x in batch_x]
batch_x = Action.pad(batch_x)
# pad batch_y only if it is a 2-level list
if len(batch_y) > 0 and isinstance(batch_y[0], list):
batch_y = Action.pad(batch_y)

# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)
batch_y = convert_to_torch_tensor(batch_y, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

if output_length:
yield (batch_x_pad, seq_len), batch_y_pad
yield (batch_x, seq_len), batch_y
else:
yield batch_x_pad, batch_y_pad
yield batch_x, batch_y
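A worked example of the new make_batch (a sketch; Batchifier and SequentialSampler are the samplers already defined in fastNLP.core.action):

    data = [[[1, 2, 3], [0, 0, 1]],    # (words, labels) sample 1
            [[4, 5], [1, 0]]]          # (words, labels) sample 2
    iterator = iter(Batchifier(SequentialSampler(data), 2, drop_last=False))
    for (batch_x, seq_len), batch_y in Action.make_batch(iterator, data, use_cuda=False, max_len=8):
        # batch_x: LongTensor [2, 3] (sample 2 zero-padded), seq_len: [3, 2]
        print(batch_x, seq_len, batch_y)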

@staticmethod
def pad(batch, fill=0):
@@ -78,6 +93,20 @@ class Action(object):
model.train()


def convert_to_torch_tensor(data_list, use_cuda):
"""
convert lists into (cuda) Tensors
:param data_list: 2-level lists
:param use_cuda: bool
:param reqired_grad: bool
:return: PyTorch Tensor of shape [batch_size, max_seq_len]
"""
data_list = torch.Tensor(data_list).long()
if torch.cuda.is_available() and use_cuda:
data_list = data_list.cuda()
return data_list
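For illustration (a tiny sketch; the cuda branch is a no-op on CPU-only machines):

    padded = [[4, 7, 2], [9, 1, 0]]
    t = convert_to_torch_tensor(padded, use_cuda=False)
    # t is a LongTensor of shape [2, 3]; it moves to GPU only if use_cuda and CUDA is available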


def k_means_1d(x, k, max_iter=100):
"""
Perform k-means on 1-D data.


fastNLP/core/inference.py (+87 -42)

@@ -2,16 +2,53 @@ import numpy as np
import torch

from fastNLP.core.action import Batchifier, SequentialSampler
from fastNLP.core.action import convert_to_torch_tensor
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
from fastNLP.modules import utils


def make_batch(iterator, data, use_cuda, output_length=False, max_len=None, min_len=None):
for indices in iterator:
batch_x = [data[idx] for idx in indices]
# record original lengths before padding
seq_len = [len(x) for x in batch_x]
batch_x = pad(batch_x)
# convert list to tensor
batch_x = convert_to_torch_tensor(batch_x, use_cuda)

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]
if min_len is not None and batch_x.size(1) < min_len:
pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x)
batch_x = torch.cat((batch_x, pad_tensor), 1)

if output_length:
yield (batch_x, seq_len)
else:
yield batch_x


def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch
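A small sketch of the min_len behaviour of the inference make_batch defined above:

    data = [[4, 7, 2], [9, 1]]
    gen = make_batch(iter([[0, 1]]), data, use_cuda=False, min_len=5)
    batch_x = next(gen)
    # padded to the longest sample (3), then zero-padded on the right up to min_len -> shape [2, 5]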


class Inference(object):
"""
This is an interface focusing on predicting output based on trained models.
It does not care about evaluations of the model, which is different from Tester.
This is a high-level model wrapper to be called by FastNLP.
This class does not share any operations with Trainer and Tester.
Currently, Inference does not support GPU.
"""

def __init__(self, pickle_path):
@@ -38,10 +75,7 @@ class Inference(object):

iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False))

num_iter = len(data) // self.batch_size

for step in range(num_iter):
batch_x = self.make_batch(iterator, data)
for batch_x in self.make_batch(iterator, data, use_cuda=False):

prediction = self.data_forward(network, batch_x)

@@ -54,35 +88,12 @@ class Inference(object):
network.eval()
else:
network.train()
self.batch_output.clear()

def data_forward(self, network, x):
raise NotImplementedError

@staticmethod
def make_batch(iterator, data, output_length=True):
indices = next(iterator)
batch_x = [data[idx] for idx in indices]
batch_x_pad = Inference.pad(batch_x)
if output_length:
seq_len = [len(x) for x in batch_x]
return [batch_x_pad, seq_len]
else:
return batch_x_pad

@staticmethod
def pad(batch, fill=0):
"""
Pad a batch of samples to maximum length.
:param batch: list of list
:param fill: word index to pad, default 0.
:return: a padded batch
"""
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch
def make_batch(self, iterator, data, use_cuda):
raise NotImplementedError

def prepare_input(self, data):
"""
@@ -101,17 +112,8 @@ class Inference(object):
data_index.append([self.word2index.get(w, default_unknown_index) for w in example])
return data_index

def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
:return:
"""
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([self.index2label[int(x)] for x in example])
return results
def prepare_output(self, batch_outputs):
raise NotImplementedError


class SeqLabelInfer(Inference):
@@ -133,10 +135,53 @@ class SeqLabelInfer(Inference):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
x = torch.Tensor(x).long()
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
y = network(x)
prediction = network.prediction(y, mask)
return torch.Tensor(prediction)
return torch.tensor(prediction, requires_grad=False)
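The mask built in data_forward marks the non-padded positions. Conceptually it behaves like this sketch (illustrative only, not the actual fastNLP.modules.utils.seq_mask implementation):

    import torch

    def seq_mask_sketch(seq_len, max_len):
        # mask[i, j] == 1 iff position j lies inside the true length of sample i
        positions = torch.arange(max_len).unsqueeze(0)   # [1, max_len]
        lengths = torch.tensor(seq_len).unsqueeze(1)     # [batch_size, 1]
        return (positions < lengths).byte()              # [batch_size, max_len]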

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=True)

def prepare_output(self, batch_outputs):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
:return:
"""
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([self.index2label[int(x)] for x in example])
return results


class ClassificationInfer(Inference):
"""
Inference on Classification models.
"""

def __init__(self, pickle_path):
super(ClassificationInfer, self).__init__(pickle_path)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=False, min_len=5)

def prepare_output(self, batch_outputs):
"""
Transform a list of batch outputs into label strings.
:param batch_outputs: list of num_batch 2-D Tensors, each of shape [batch_size, num_classes].
:return: list of predicted labels, one per example
"""
results = []
for batch_out in batch_outputs:
# take the argmax over the class dimension, one prediction per example in the batch
indices = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend(self.index2label[int(i)] for i in indices)
return results
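Note on min_len=5 in make_batch above: presumably a floor so that even very short inputs remain at least as wide as the largest convolution kernel in CNNText; treat the exact value as model-specific.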

fastNLP/core/tester.py (+28 -129)

@@ -1,5 +1,4 @@
import _pickle
import os

import numpy as np
import torch
@@ -9,15 +8,14 @@ from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.modules import utils


class BaseTester(Action):
class BaseTester(object):
"""docstring for Tester"""

def __init__(self, test_args, action=None):
def __init__(self, test_args):
"""
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
super(BaseTester, self).__init__()
self.action = action if action is not None else Action()
self.validate_in_training = test_args["validate_in_training"]
self.save_dev_data = None
self.save_output = test_args["save_output"]
@@ -39,16 +37,23 @@ class BaseTester(Action):
else:
self.model = network

# no backward setting for model
for param in network.parameters():
param.requires_grad = False

# turn on the testing mode; clean up the history
self.action.mode(network, test=True)
self.mode(network, test=True)
self.eval_history.clear()
self.batch_output.clear()

dev_data = self.prepare_input(self.pickle_path)

iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
n_batches = len(dev_data) // self.batch_size
n_print = 1
step = 0

for batch_x, batch_y in self.action.make_batch(iterator, dev_data):
for batch_x, batch_y in self.make_batch(iterator, dev_data):

prediction = self.data_forward(network, batch_x)

@@ -58,6 +63,7 @@ class BaseTester(Action):
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)
step += 1

def prepare_input(self, data_path):
"""
@@ -70,6 +76,9 @@ class BaseTester(Action):
self.save_dev_data = data_dev
return self.save_dev_data

def mode(self, model, test):
Action.mode(model, test)

def data_forward(self, network, x):
raise NotImplementedError

@@ -87,17 +96,20 @@ class BaseTester(Action):
"""
raise NotImplementedError

def make_batch(self, iterator, data):
raise NotImplementedError
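A typical invocation, mirroring test/test_tester.py further down (a sketch; the args dict is abbreviated):

    test_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                 "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
                 "use_cuda": True}
    tester = SeqLabelTester(test_args)
    tester.test(model)   # runs make_batch -> data_forward -> evaluate over the dev set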


class POSTester(BaseTester):
class SeqLabelTester(BaseTester):
"""
Tester for sequence labeling.
"""

def __init__(self, test_args, action=None):
def __init__(self, test_args):
"""
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]"
"""
super(POSTester, self).__init__(test_args, action)
super(SeqLabelTester, self).__init__(test_args)
self.max_len = None
self.mask = None
self.batch_result = None
@@ -107,13 +119,10 @@ class POSTester(BaseTester):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
x = torch.Tensor(x).long()
batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)

if torch.cuda.is_available() and self.use_cuda:
x = x.cuda()
mask = mask.cuda()
self.mask = mask

@@ -121,9 +130,6 @@ class POSTester(BaseTester):
return y

def evaluate(self, predict, truth):
truth = torch.Tensor(truth)
if torch.cuda.is_available() and self.use_cuda:
truth = truth.cuda()
batch_size, max_len = predict.size(0), predict.size(1)
loss = self.model.loss(predict, truth, self.mask) / batch_size

@@ -147,8 +153,11 @@ class POSTester(BaseTester):
loss, accuracy = self.metrics()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)


class ClassTester(BaseTester):
class ClassificationTester(BaseTester):
"""Tester for classification."""

def __init__(self, test_args):
@@ -156,7 +165,7 @@ class ClassTester(BaseTester):
:param test_args: a dict-like object that has __getitem__ method, \
can be accessed by "test_args["key_str"]"
"""
# super(ClassTester, self).__init__()
super(ClassificationTester, self).__init__(test_args)
self.pickle_path = test_args["pickle_path"]

self.save_dev_data = None
@@ -164,111 +173,8 @@ class ClassTester(BaseTester):
self.mean_loss = None
self.iterator = None

if "test_name" in test_args:
self.test_name = test_args["test_name"]
else:
self.test_name = "data_test.pkl"

if "validate_in_training" in test_args:
self.validate_in_training = test_args["validate_in_training"]
else:
self.validate_in_training = False

if "save_output" in test_args:
self.save_output = test_args["save_output"]
else:
self.save_output = False

if "save_loss" in test_args:
self.save_loss = test_args["save_loss"]
else:
self.save_loss = True

if "batch_size" in test_args:
self.batch_size = test_args["batch_size"]
else:
self.batch_size = 50
if "use_cuda" in test_args:
self.use_cuda = test_args["use_cuda"]
else:
self.use_cuda = True

if "max_len" in test_args:
self.max_len = test_args["max_len"]
else:
self.max_len = None

self.model = None
self.eval_history = []
self.batch_output = []

def test(self, network):
# prepare model
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network

# no backward setting for model
for param in self.model.parameters():
param.requires_grad = False

# turn on the testing mode; clean up the history
self.mode(network, test=True)

# prepare test data
data_test = self.prepare_input(self.pickle_path, self.test_name)

# data generator
self.iterator = iter(Batchifier(
RandomSampler(data_test), self.batch_size, drop_last=False))

# test
n_batches = len(data_test) // self.batch_size
n_print = n_batches // 10
step = 0
for batch_x, batch_y in self.make_batch(data_test, max_len=self.max_len):
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)

if step % n_print == 0:
print("step: {:>5}".format(step))

step += 1

def prepare_input(self, data_dir, file_name):
"""Prepare data."""
file_path = os.path.join(data_dir, file_name)
with open(file_path, 'rb') as f:
data = _pickle.load(f)
return data

def make_batch(self, data, max_len=None):
"""Batch and pad data."""
for indices in self.iterator:
# generate batch and pad
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

# convert to tensor
batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

yield batch_x, batch_y
def make_batch(self, iterator, data, max_len=None):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=False, max_len=max_len)

def data_forward(self, network, x):
"""Forward through network."""
@@ -289,10 +195,3 @@ class ClassTester(BaseTester):
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

def mode(self, model, test=True):
"""TODO: combine this function with Trainer ?? """
if test:
model.eval()
else:
model.train()
self.eval_history.clear()

fastNLP/core/trainer.py (+36 -71)

@@ -9,12 +9,12 @@ import torch.nn as nn

from fastNLP.core.action import Action
from fastNLP.core.action import RandomSampler, Batchifier
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
from fastNLP.modules import utils
from fastNLP.saver.model_saver import ModelSaver


class BaseTrainer(Action):
class BaseTrainer(object):
"""Base trainer for all trainers.
Trainer receives a model and data, and then performs training.

@@ -24,10 +24,9 @@ class BaseTrainer(Action):
- get_loss
"""

def __init__(self, train_args, action=None):
def __init__(self, train_args):
"""
:param train_args: dict of (key, value), or dict-like object. key is str.
:param action: (optional) an Action object that wrap most operations shared by Trainer, Tester, and Inference.

The base trainer requires the following keys:
- epochs: int, the number of epochs in training
@@ -36,7 +35,6 @@ class BaseTrainer(Action):
- pickle_path: str, the path to pickle files for pre-processing
"""
super(BaseTrainer, self).__init__()
self.action = action if action is not None else Action()
self.n_epochs = train_args["epochs"]
self.batch_size = train_args["batch_size"]
self.pickle_path = train_args["pickle_path"]
@@ -79,7 +77,7 @@ class BaseTrainer(Action):
default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
"use_cuda": self.use_cuda}
validator = POSTester(default_valid_args, self.action)
validator = self._create_validator(default_valid_args)

self.define_optimizer()

@@ -92,12 +90,12 @@ class BaseTrainer(Action):
for epoch in range(1, self.n_epochs + 1):

# turn on network training mode; prepare batch iterator
self.action.mode(network, test=False)
self.mode(network, test=False)
iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))

# training iterations in one epoch
step = 0
for batch_x, batch_y in self.action.make_batch(iterator, data_train, output_length=True):
for batch_x, batch_y in self.make_batch(iterator, data_train):

prediction = self.data_forward(network, batch_x)

@@ -142,6 +140,12 @@ class BaseTrainer(Action):
files.append(data)
return tuple(files)

def make_batch(self, iterator, data):
raise NotImplementedError

def mode(self, network, test):
Action.mode(network, test)

def define_optimizer(self):
"""
Define framework-specific optimizer specified by the models.
@@ -203,6 +207,9 @@ class BaseTrainer(Action):
"""
ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network)

def _create_validator(self, valid_args):
raise NotImplementedError
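Subclasses override make_batch and _create_validator to plug into the shared training loop; a sketch of the pattern (MyTaskTrainer is hypothetical):

    class MyTaskTrainer(BaseTrainer):
        def make_batch(self, iterator, data):
            # reuse the shared batching; lengths are not needed for classification-style tasks
            return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=False)

        def _create_validator(self, valid_args):
            # hand back the Tester that matches this task
            return ClassificationTester(valid_args)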


class ToyTrainer(BaseTrainer):
"""
@@ -217,12 +224,6 @@ class ToyTrainer(BaseTrainer):
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
return data_train, data_dev, 0, 1

def mode(self, test=False):
if test:
self.model.eval()
else:
self.model.train()

def data_forward(self, network, x):
return network(x)

@@ -246,8 +247,8 @@ class SeqLabelTrainer(BaseTrainer):

"""

def __init__(self, train_args, action=None):
super(SeqLabelTrainer, self).__init__(train_args, action)
def __init__(self, train_args):
super(SeqLabelTrainer, self).__init__(train_args)
self.vocab_size = train_args["vocab_size"]
self.num_classes = train_args["num_classes"]
self.max_len = None
@@ -269,14 +270,12 @@ class SeqLabelTrainer(BaseTrainer):
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.")
# unpack the returned value from make_batch
x, seq_len = inputs[0], inputs[1]
x = torch.Tensor(x).long()

batch_size, max_len = x.size(0), x.size(1)
mask = utils.seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)

if torch.cuda.is_available() and self.use_cuda:
x = x.cuda()
mask = mask.cuda()
self.mask = mask

@@ -290,9 +289,6 @@ class SeqLabelTrainer(BaseTrainer):
:param truth: ground truth label vector, [batch_size, max_len]
:return: a scalar
"""
truth = torch.Tensor(truth)
if torch.cuda.is_available() and self.use_cuda:
truth = truth.cuda()
batch_size, max_len = predict.size(0), predict.size(1)
assert truth.shape == (batch_size, max_len)

@@ -307,32 +303,18 @@ class SeqLabelTrainer(BaseTrainer):
else:
return False

def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, output_length=True, use_cuda=self.use_cuda)

class LanguageModelTrainer(BaseTrainer):
"""
Trainer for Language Model
"""

def __init__(self, train_args):
super(LanguageModelTrainer, self).__init__(train_args)

def prepare_input(self, data_path):
pass
def _create_validator(self, valid_args):
return SeqLabelTester(valid_args)


class ClassTrainer(BaseTrainer):
class ClassificationTrainer(BaseTrainer):
"""Trainer for classification."""

def __init__(self, train_args, action=None):
super(ClassTrainer, self).__init__(train_args, action)
self.n_epochs = train_args["epochs"]
self.batch_size = train_args["batch_size"]
self.pickle_path = train_args["pickle_path"]

if "validate" in train_args:
self.validate = train_args["validate"]
else:
self.validate = False
def __init__(self, train_args):
super(ClassificationTrainer, self).__init__(train_args)
if "learn_rate" in train_args:
self.learn_rate = train_args["learn_rate"]
else:
@@ -341,15 +323,11 @@ class ClassTrainer(BaseTrainer):
self.momentum = train_args["momentum"]
else:
self.momentum = 0.9
if "use_cuda" in train_args:
self.use_cuda = train_args["use_cuda"]
else:
self.use_cuda = True

self.model = None
self.iterator = None
self.loss_func = None
self.optimizer = None
self.best_accuracy = 0

def define_loss(self):
self.loss_func = nn.CrossEntropyLoss()
@@ -365,9 +343,6 @@ class ClassTrainer(BaseTrainer):

def data_forward(self, network, x):
"""Forward through network."""
x = torch.Tensor(x).long()
if torch.cuda.is_available() and self.use_cuda:
x = x.cuda()
logits = network(x)
return logits

@@ -380,31 +355,21 @@ class ClassTrainer(BaseTrainer):
"""Apply gradient."""
self.optimizer.step()

"""
def make_batch(self, data):
for indices in self.iterator:
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

batch_x = torch.Tensor(batch_x).long()
batch_y = torch.Tensor(batch_y).long()
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

yield batch_x, batch_y
"""
def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, output_length=False, use_cuda=self.use_cuda)

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)

def best_eval_result(self, validator):
_, _, accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False

if __name__ == "__name__":
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
trainer = BaseTrainer(train_args)
data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10]
trainer.make_batch(data=data_train)
def _create_validator(self, valid_args):
return ClassificationTester(valid_args)

fastNLP/models/cnn_text_classification.py (+3 -2)

@@ -1,13 +1,14 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn

# import torch.nn.functional as F
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool


class CNNText(BaseModel):
class CNNText(torch.nn.Module):
"""
Text classification model by character CNN, an implementation of the paper
'Yoon Kim. 2014. Convolutional Neural Networks for Sentence


reproduction/chinese_word_seg/cws_train.py (+2 -2)

@@ -8,7 +8,7 @@ from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference

@@ -96,7 +96,7 @@ def test():
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

# Tester
tester = POSTester(test_args)
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)


test/seq_labeling.py (+3 -3)

@@ -8,7 +8,7 @@ from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import SeqLabelInfer

@@ -101,7 +101,7 @@ def train_and_test():
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

# Tester
tester = POSTester(test_args)
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)
@@ -112,5 +112,5 @@ def train_and_test():


if __name__ == "__main__":
train_and_test()
# train_and_test()
infer()

test/test_cws.py (+2 -2)

@@ -8,7 +8,7 @@ from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.inference import Inference

@@ -101,7 +101,7 @@ def train_test():
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})

# Tester
tester = POSTester(test_args)
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)


test/test_tester.py (+2 -2)

@@ -1,4 +1,4 @@
from fastNLP.core.tester import POSTester
from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
@@ -26,7 +26,7 @@ def foo():
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
"use_cuda": True}
validator = POSTester(valid_args)
validator = SeqLabelTester(valid_args)

print("start validation.")
validator.test(model)


test/text_classify.py (+46 -6)

@@ -3,16 +3,45 @@

import os

from fastNLP.core.trainer import ClassTrainer
from fastNLP.core.inference import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver

if __name__ == "__main__":
data_dir = "./data_for_tests/"
train_file = 'text_classify.txt'
model_name = "model_class.pkl"
data_dir = "./data_for_tests/"
train_file = 'text_classify.txt'
model_name = "model_class.pkl"


def infer():
# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
data = ds_loader.load()
unlabeled_data = [x[0] for x in data]

# pre-process data
pre = ClassPreprocess(data_dir)
vocab_size, n_classes = pre.process(data, "data_train.pkl")
print("vocabulary size:", vocab_size)
print("number of classes:", n_classes)

# construct model
print("Building model...")
cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
# Dump trained parameters into the model
ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
print("model loaded!")

infer = ClassificationInfer(data_dir)
results = infer.predict(cnn, unlabeled_data)
print(results)


def train():
# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
@@ -40,5 +69,16 @@ if __name__ == "__main__":
"model_saved_path": "./data_for_tests/",
"use_cuda": True
}
trainer = ClassTrainer(train_args)
trainer = ClassificationTrainer(train_args)
trainer.train(cnn)

print("Training finished!")

saver = ModelSaver("./data_for_tests/saved_model.pkl")
saver.save_pytorch(cnn)
print("Model saved!")


if __name__ == "__main__":
# train()
infer()
