
classification

tags/v0.1.0
Ke Zhen, 6 years ago
parent commit 4a25bdba9c
6 changed files with 640 additions and 10 deletions
  1. +151 -0  fastNLP/action/tester.py
  2. +199 -0  fastNLP/action/trainer.py
  3. +14  -9  fastNLP/loader/dataset_loader.py
  4. +186 -1  fastNLP/loader/preprocess.py
  5. +37  -0  fastNLP/models/cnn_text_classification.py
  6. +53  -0  fastNLP/modules/encoder/conv_maxpool.py

+151 -0  fastNLP/action/tester.py

@@ -2,6 +2,7 @@ import _pickle

import numpy as np
import torch
import os

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -174,3 +175,153 @@ class POSTester(BaseTester):
"""
loss, accuracy = self.matrices()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)


class ClassTester(BaseTester):
"""Tester for classification."""

def __init__(self, test_args):
"""
:param test_args: a dict-like object with a __getitem__ method, \
whose values are accessed as test_args["key_str"]
"""
# super(ClassTester, self).__init__()
self.pickle_path = test_args["pickle_path"]

self.save_dev_data = None
self.output = None
self.mean_loss = None
self.iterator = None

if "test_name" in test_args:
self.test_name = test_args["test_name"]
else:
self.test_name = "data_test.pkl"

if "validate_in_training" in test_args:
self.validate_in_training = test_args["validate_in_training"]
else:
self.validate_in_training = False

if "save_output" in test_args:
self.save_output = test_args["save_output"]
else:
self.save_output = False

if "save_loss" in test_args:
self.save_loss = test_args["save_loss"]
else:
self.save_loss = True

if "batch_size" in test_args:
self.batch_size = test_args["batch_size"]
else:
self.batch_size = 50
if "use_cuda" in test_args:
self.use_cuda = test_args["use_cuda"]
else:
self.use_cuda = True

if "max_len" in test_args:
self.max_len = test_args["max_len"]
else:
self.max_len = None

self.model = None
self.eval_history = []
self.batch_output = []

def test(self, network):
# prepare model
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network

# no backward setting for model
for param in self.model.parameters():
param.requires_grad = False

# turn on the testing mode; clean up the history
self.mode(network, test=True)

# prepare test data
data_test = self.prepare_input(self.pickle_path, self.test_name)

# data generator
self.iterator = iter(Batchifier(
RandomSampler(data_test), self.batch_size, drop_last=False))

# test
n_batches = len(data_test) // self.batch_size
n_print = max(n_batches // 10, 1)  # avoid modulo-by-zero when there are fewer than 10 batches
step = 0
for batch_x, batch_y in self.batchify(data_test, max_len=self.max_len):
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)

if step % n_print == 0:
print("step: {:>5}".format(step))

step += 1

def prepare_input(self, data_dir, file_name):
"""Prepare data."""
file_path = os.path.join(data_dir, file_name)
with open(file_path, 'rb') as f:
data = _pickle.load(f)
return data

def batchify(self, data, max_len=None):
"""Batch and pad data."""
for indices in self.iterator:
# generate batch and pad
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

# convert to tensor
batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

yield batch_x, batch_y

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def evaluate(self, y_logit, y_true):
"""Return y_pred and y_true."""
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

def matrices(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

def mode(self, model, test=True):
"""To do: combine this function with Trainer ?? """
if test:
model.eval()
else:
model.train()
self.eval_history.clear()
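
A minimal usage sketch of the new ClassTester (not part of this commit). The pickle directory, the argument values, and the trained_model variable are illustrative assumptions; it relies only on the interface defined above:

    from fastNLP.action.tester import ClassTester

    test_args = {
        "pickle_path": "./save/",      # assumed directory containing data_test.pkl
        "test_name": "data_test.pkl",
        "batch_size": 32,
        "use_cuda": False,
        "save_output": False,
        "save_loss": True,             # keep eval history so matrices() has data
    }
    tester = ClassTester(test_args)
    tester.test(trained_model)         # trained_model: e.g. a trained CNNText instance
    y_true, y_prob, acc = tester.matrices()
    print("test accuracy: {:.2%}".format(acc))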

+199 -0  fastNLP/action/trainer.py

@@ -2,6 +2,10 @@ import _pickle

import numpy as np
import torch
import torch.nn as nn
import os
from time import time
from datetime import timedelta

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -392,6 +396,201 @@ class POSTrainer(BaseTrainer):
return loss


class ClassTrainer(BaseTrainer):
"""Trainer for classification."""

def __init__(self, train_args):
# super(ClassTrainer, self).__init__(train_args)
self.n_epochs = train_args["epochs"]
self.batch_size = train_args["batch_size"]
self.pickle_path = train_args["pickle_path"]

if "validate" in train_args:
self.validate = train_args["validate"]
else:
self.validate = False
if "learn_rate" in train_args:
self.learn_rate = train_args["learn_rate"]
else:
self.learn_rate = 1e-3
if "momentum" in train_args:
self.momentum = train_args["momentum"]
else:
self.momentum = 0.9
if "use_cuda" in train_args:
self.use_cuda = train_args["use_cuda"]
else:
self.use_cuda = True

self.model = None
self.iterator = None
self.loss_func = None
self.optimizer = None

def train(self, network):
"""General Training Steps
:param network: a model

The method is framework-independent.
It works by calling the following methods:
- prepare_input
- mode
- define_optimizer
- data_forward
- get_loss
- grad_backward
- update
Subclasses must implement these methods with a specific framework.
"""
# prepare model and data, transfer model to gpu if available
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network
data_train, data_dev, data_test, embedding = self.prepare_input(
self.pickle_path)

# define tester over dev data
# valid_args = {
# "save_output": True, "validate_in_training": True,
# "save_dev_input": True, "save_loss": True,
# "batch_size": self.batch_size, "pickle_path": self.pickle_path}
# validator = POSTester(valid_args)

# turn on network training mode, define loss and optimizer
self.define_loss()
self.define_optimizer()
self.mode(test=False)

# main training epochs
start = time()
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_print = max(n_batches // 10, 1)  # avoid modulo-by-zero when there are fewer than 10 batches
for epoch in range(self.n_epochs):
# prepare batch iterator
self.iterator = iter(Batchifier(
RandomSampler(data_train), self.batch_size, drop_last=False))

# training iterations in one epoch
step = 0
for batch_x, batch_y in self.batchify(data_train):
prediction = self.data_forward(network, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()

if step % n_print == 0:
acc = self.get_acc(prediction, batch_y)
end = time()
diff = timedelta(seconds=round(end - start))
print("epoch: {:>3} step: {:>4} loss: {:>4.2}"
" train acc: {:>5.1%} time: {}".format(
epoch, step, loss, acc, diff))

step += 1

# if self.validate:
# if data_dev is None:
# raise RuntimeError("No validation data provided.")
# validator.test(network)
# print("[epoch {}]".format(epoch), end=" ")
# print(validator.show_matrices())

# finish training

def prepare_input(self, data_path):
"""
Load the pickled train/dev/test data sets and the embedding, if present.
"""

names = [
"data_train.pkl", "data_dev.pkl",
"data_test.pkl", "embedding.pkl"]

files = []
for name in names:
file_path = os.path.join(data_path, name)
if os.path.exists(file_path):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
data = []
files.append(data)

return tuple(files)

def mode(self, test=False):
"""
Tell the network to be trained or not.
:param test: bool
"""
if test:
self.model.eval()
else:
self.model.train()

def define_loss(self):
"""
Assign an instance of loss function to self.loss_func
E.g. self.loss_func = nn.CrossEntropyLoss()
"""
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
else:
self.loss_func = nn.CrossEntropyLoss()

def define_optimizer(self):
"""
Define framework-specific optimizer specified by the models.
"""
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=self.learn_rate,
momentum=self.momentum)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def get_loss(self, predict, truth):
"""Calculate loss."""
return self.loss_func(predict, truth)

def grad_backward(self, loss):
"""Compute gradient backward."""
self.model.zero_grad()
loss.backward()

def update(self):
"""Apply gradient."""
self.optimizer.step()

def batchify(self, data):
"""Batch and pad data."""
for indices in self.iterator:
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

yield batch_x, batch_y

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)


if __name__ == "__name__":
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
trainer = BaseTrainer(train_args)
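
A hypothetical end-to-end sketch wiring ClassTrainer to the new CNNText model (not part of this commit). The pickle_path directory, the hyper-parameter values, and the vocabulary/class sizes are assumptions; data_train.pkl etc. are expected to exist under pickle_path, e.g. produced by ClassPreprocess:

    from fastNLP.action.trainer import ClassTrainer
    from fastNLP.models.cnn_text_classification import CNNText

    train_args = {
        "epochs": 5,
        "batch_size": 32,
        "pickle_path": "./save/",   # assumed directory with data_train.pkl etc.
        "validate": False,
        "learn_rate": 1e-3,
        "momentum": 0.9,
        "use_cuda": False,
    }
    # embed_num must cover the word2id vocabulary, class_num the label set
    model = CNNText(class_num=5, embed_num=10000, embed_dim=300)
    trainer = ClassTrainer(train_args)
    trainer.train(model)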


+14 -9  fastNLP/loader/dataset_loader.py

@@ -29,11 +29,11 @@ class POSDatasetLoader(DatasetLoader):
return lines


- class ClassificationDatasetLoader(DatasetLoader):
- """loader for classfication data sets"""
+ class ClassDatasetLoader(DatasetLoader):
+ """Loader for classification data sets"""

def __init__(self, data_name, data_path):
- super(ClassificationDatasetLoader, data_name).__init__()
+ super(ClassDatasetLoader, self).__init__(data_name, data_path)

def load(self):
assert os.path.exists(self.data_path)
@@ -44,16 +44,21 @@ class ClassificationDatasetLoader(DatasetLoader):
@staticmethod
def parse(lines):
"""
- :param lines: lines from dataset
- :return: list(list(list())): the three level of lists are
+ Params
+     lines: lines from dataset
+ Return
+     list(list(list())): the three level of lists are
words, sentence, and dataset
"""
dataset = list()
for line in lines:
- label = line.split(" ")[0]
- words = line.split(" ")[1:]
- word = list([w for w in words])
- sentence = list([word, label])
+ line = line.strip().split()
+ label = line[0]
+ words = line[1:]
+ if len(words) <= 1:
+     continue
+
+ sentence = [words, label]
dataset.append(sentence)
return dataset
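
For illustration, a small sketch of what the reworked parse() produces, assuming the usual "label word word ..." line format; the sample lines are made up:

    from fastNLP.loader.dataset_loader import ClassDatasetLoader

    lines = [
        "pos this movie is great",
        "neg boring and slow",
    ]
    dataset = ClassDatasetLoader.parse(lines)
    # dataset == [[["this", "movie", "is", "great"], "pos"],
    #             [["boring", "and", "slow"], "neg"]]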



+186 -1  fastNLP/loader/preprocess.py

@@ -198,4 +198,189 @@ class POSPreprocess(BasePreprocess):
pass
def data_test(self):
- pass
+ pass
class ClassPreprocess(BasePreprocess):
"""
Pre-process the classification datasets.
Params:
pickle_path - directory to save result of pre-processing
Saves:
word2id.pkl
id2word.pkl
class2id.pkl
id2class.pkl
embedding.pkl
data_train.pkl
data_dev.pkl
data_test.pkl
"""
def __init__(self, pickle_path):
# super(ClassPreprocess, self).__init__(data, pickle_path)
self.word_dict = None
self.label_dict = None
self.pickle_path = pickle_path # save directory
def process(self, data, save_name):
"""
Process data.
Params:
data - nested list, data = [sample1, sample2, ...],
sample = [sentence, label], sentence = [word1, word2, ...]
save_name - name of processed data, such as data_train.pkl
Returns:
vocab_size - vocabulary size
n_classes - number of classes
"""
self.build_dict(data)
self.word2id()
vocab_size = self.id2word()
self.class2id()
num_classes = self.id2class()
self.embedding()
self.data_generate(data, save_name)
return vocab_size, num_classes
def build_dict(self, data):
"""Build vocabulary."""
# just read if word2id.pkl and class2id.pkl exists
if self.pickle_exist("word2id.pkl") and \
self.pickle_exist("class2id.pkl"):
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, 'rb') as f:
self.word_dict = _pickle.load(f)
file_name = os.path.join(self.pickle_path, "class2id.pkl")
with open(file_name, 'rb') as f:
self.label_dict = _pickle.load(f)
return
# build vocabulary from scratch if nothing exists
self.word_dict = {
DEFAULT_PADDING_LABEL: 0,
DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2,
DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
self.label_dict = {}
# collect every word and label
for sent, label in data:
if len(sent) <= 1:
continue
if label not in self.label_dict:
index = len(self.label_dict)
self.label_dict[label] = index
for word in sent:
if word not in self.word_dict:
index = len(self.word_dict)
self.word_dict[word] = index
def pickle_exist(self, pickle_name):
"""
Check whether a pickle file exists.
Params
pickle_name: the filename of target pickle file
Return
True if file exists else False
"""
if not os.path.exists(self.pickle_path):
os.makedirs(self.pickle_path)
file_name = os.path.join(self.pickle_path, pickle_name)
if os.path.exists(file_name):
return True
else:
return False
def word2id(self):
"""Save vocabulary of {word:id} mapping format."""
# nothing will be done if word2id.pkl exists
if self.pickle_exist("word2id.pkl"):
return
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
def id2word(self):
"""Save vocabulary of {id:word} mapping format."""
# nothing will be done if id2word.pkl exists
if self.pickle_exist("id2word.pkl"):
file_name = os.path.join(self.pickle_path, "id2word.pkl")
with open(file_name, 'rb') as f:
id2word_dict = _pickle.load(f)
return len(id2word_dict)
id2word_dict = {self.word_dict[w]: w for w in self.word_dict}
file_name = os.path.join(self.pickle_path, "id2word.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2word_dict, f)
return len(id2word_dict)
def class2id(self):
"""Save mapping of {class:id}."""
# nothing will be done if class2id.pkl exists
if self.pickle_exist("class2id.pkl"):
return
file_name = os.path.join(self.pickle_path, "class2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.label_dict, f)
def id2class(self):
"""Save mapping of {id:class}."""
# nothing will be done if id2class.pkl exists
if self.pickle_exist("id2class.pkl"):
file_name = os.path.join(self.pickle_path, "id2class.pkl")
with open(file_name, "rb") as f:
id2class_dict = _pickle.load(f)
return len(id2class_dict)
id2class_dict = {self.label_dict[c]: c for c in self.label_dict}
file_name = os.path.join(self.pickle_path, "id2class.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2class_dict, f)
return len(id2class_dict)
def embedding(self):
"""Save embedding lookup table corresponding to vocabulary."""
# nothing will be done if embedding.pkl exists
if self.pickle_exist("embedding.pkl"):
return
# retrieve vocabulary from pre-trained embedding (not implemented)
def data_generate(self, data_src, save_name):
"""Convert dataset from text to digit."""
# nothing will be done if file exists
save_path = os.path.join(self.pickle_path, save_name)
if os.path.exists(save_path):
return
data = []
# for every sample
for sent, label in data_src:
if len(sent) <= 1:
continue
label_id = self.label_dict[label] # label id
sent_id = [] # sentence ids
for word in sent:
if word in self.word_dict:
sent_id.append(self.word_dict[word])
else:
sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL])
data.append([sent_id, label_id])
# save data
with open(save_path, "wb") as f:
_pickle.dump(data, f)
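
A hedged sketch of how ClassPreprocess might be driven from ClassDatasetLoader output (not part of this commit). The file paths are assumptions, and it assumes load() returns the [words, label] pairs produced by parse():

    from fastNLP.loader.dataset_loader import ClassDatasetLoader
    from fastNLP.loader.preprocess import ClassPreprocess

    loader = ClassDatasetLoader("sample", "./data/sample.txt")   # hypothetical data set
    data = loader.load()                                         # [[words, label], ...]

    preprocess = ClassPreprocess(pickle_path="./save/")
    vocab_size, n_classes = preprocess.process(data, "data_train.pkl")
    # word2id.pkl, class2id.pkl, data_train.pkl, ... are now under ./save/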

+37 -0  fastNLP/models/cnn_text_classification.py

@@ -0,0 +1,37 @@
# python: 3.6
# encoding: utf-8

import torch.nn as nn
# import torch.nn.functional as F
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool


class CNNText(BaseModel):
"""
Text classification model with a CNN encoder, an implementation of the paper
'Yoon Kim. 2014. Convolutional Neural Networks for Sentence
Classification.'
"""

def __init__(self, class_num=9,
kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5],
embed_num=1000, embed_dim=300, pretrained_embed=None,
drop_prob=0.5):
super(CNNText, self).__init__()

# no support for pre-trained embedding currently
self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
self.conv_pool = ConvMaxpool(
in_channels=embed_dim,
out_channels=kernel_nums,
kernel_sizes=kernel_sizes)
self.dropout = nn.Dropout(drop_prob)
self.fc = nn.Linear(sum(kernel_nums), class_num)

def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
return x
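
A quick shape check for the new CNNText model; the constructor values and the random batch are arbitrary, chosen only so that the sentence length exceeds the largest kernel size:

    import torch
    from fastNLP.models.cnn_text_classification import CNNText

    model = CNNText(class_num=5, embed_num=200, embed_dim=50,
                    kernel_nums=[30, 30, 30], kernel_sizes=[3, 4, 5])
    x = torch.randint(1, 200, (8, 20), dtype=torch.long)  # [N, L]: 8 sentences, 20 token ids each
    logits = model(x)
    print(logits.shape)  # torch.Size([8, 5])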

+53 -0  fastNLP/modules/encoder/conv_maxpool.py

@@ -0,0 +1,53 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMaxpool(nn.Module):
"""
Convolution and max-pooling module with multiple kernel sizes.
"""

def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu'):
super(ConvMaxpool, self).__init__()

# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
if isinstance(kernel_sizes, int):
out_channels = [out_channels]
kernel_sizes = [kernel_sizes]
self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
out_channels=oc,
kernel_size=ks,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
for oc, ks in zip(out_channels, kernel_sizes)])
else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')

# activation function
if activation == 'relu':
self.activation = F.relu
else:
raise Exception(
"Undefined activation function: choose from: relu")

def forward(self, x):
# [N,L,C] -> [N,C,L]
x = torch.transpose(x, 1, 2)
# convolution
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L]]
# max-pooling
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
for i in xs] # [[N, C]]
return torch.cat(xs, dim=-1) # [N,C]
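
A small shape-check sketch for ConvMaxpool; the channel and kernel values are arbitrary and only demonstrate the [N, L, C] -> [N, sum(out_channels)] contract described above:

    import torch
    from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool

    conv_pool = ConvMaxpool(in_channels=16, out_channels=[8, 8], kernel_sizes=[2, 3])
    x = torch.randn(4, 10, 16)   # [N, L, C]
    out = conv_pool(x)
    print(out.shape)             # torch.Size([4, 16]) -- 8 + 8 pooled channels concatenated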
