
classification

tags/v0.1.0
Ke Zhen 6 years ago
parent
commit
4a25bdba9c
6 changed files with 640 additions and 10 deletions
  1. +151  -0   fastNLP/action/tester.py
  2. +199  -0   fastNLP/action/trainer.py
  3. +14   -9   fastNLP/loader/dataset_loader.py
  4. +186  -1   fastNLP/loader/preprocess.py
  5. +37   -0   fastNLP/models/cnn_text_classification.py
  6. +53   -0   fastNLP/modules/encoder/conv_maxpool.py

+ 151  - 0   fastNLP/action/tester.py

@@ -2,6 +2,7 @@ import _pickle


import numpy as np
import torch
import os


from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -174,3 +175,153 @@ class POSTester(BaseTester):
""" """
loss, accuracy = self.matrices() loss, accuracy = self.matrices()
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)


class ClassTester(BaseTester):
"""Tester for classification."""

def __init__(self, test_args):
"""
:param test_args: a dict-like object with a __getitem__ method; values
are accessed as test_args["key_str"].
"""
# super(ClassTester, self).__init__()
self.pickle_path = test_args["pickle_path"]

self.save_dev_data = None
self.output = None
self.mean_loss = None
self.iterator = None

if "test_name" in test_args:
self.test_name = test_args["test_name"]
else:
self.test_name = "data_test.pkl"

if "validate_in_training" in test_args:
self.validate_in_training = test_args["validate_in_training"]
else:
self.validate_in_training = False

if "save_output" in test_args:
self.save_output = test_args["save_output"]
else:
self.save_output = False

if "save_loss" in test_args:
self.save_loss = test_args["save_loss"]
else:
self.save_loss = True

if "batch_size" in test_args:
self.batch_size = test_args["batch_size"]
else:
self.batch_size = 50
if "use_cuda" in test_args:
self.use_cuda = test_args["use_cuda"]
else:
self.use_cuda = True

if "max_len" in test_args:
self.max_len = test_args["max_len"]
else:
self.max_len = None

self.model = None
self.eval_history = []
self.batch_output = []

def test(self, network):
# prepare model
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network

# disable gradient computation for the model's parameters
for param in self.model.parameters():
param.requires_grad = False

# turn on the testing mode; clean up the history
self.mode(network, test=True)

# prepare test data
data_test = self.prepare_input(self.pickle_path, self.test_name)

# data generator
self.iterator = iter(Batchifier(
RandomSampler(data_test), self.batch_size, drop_last=False))

# test
n_batches = len(data_test) // self.batch_size
n_print = n_batches // 10
step = 0
for batch_x, batch_y in self.batchify(data_test, max_len=self.max_len):
prediction = self.data_forward(network, batch_x)
eval_results = self.evaluate(prediction, batch_y)

if self.save_output:
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)

if step % n_print == 0:
print("step: {:>5}".format(step))

step += 1

def prepare_input(self, data_dir, file_name):
"""Prepare data."""
file_path = os.path.join(data_dir, file_name)
with open(file_path, 'rb') as f:
data = _pickle.load(f)
return data

def batchify(self, data, max_len=None):
"""Batch and pad data."""
for indices in self.iterator:
# generate batch and pad
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

# convert to tensor
batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

# trim data to max_len
if max_len is not None and batch_x.size(1) > max_len:
batch_x = batch_x[:, :max_len]

yield batch_x, batch_y

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def evaluate(self, y_logit, y_true):
"""Return y_pred and y_true."""
y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
return [y_prob, y_true]

def matrices(self):
"""Compute accuracy."""
y_prob, y_true = zip(*self.eval_history)
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(y_true, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

def mode(self, model, test=True):
"""To do: combine this function with Trainer ?? """
if test:
model.eval()
else:
model.train()
self.eval_history.clear()
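
A minimal usage sketch for ClassTester: it assumes a data_test.pkl already produced by ClassPreprocess under ./save/, and the model sizes and paths below are illustrative assumptions, not values taken from the commit.

import torch
from fastNLP.action.tester import ClassTester
from fastNLP.models.cnn_text_classification import CNNText

test_args = {                     # keys mirror those read in ClassTester.__init__
    "pickle_path": "./save/",     # assumed directory holding data_test.pkl
    "test_name": "data_test.pkl",
    "batch_size": 32,
    "use_cuda": torch.cuda.is_available(),
    "save_output": True,
    "save_loss": True,
}

model = CNNText(class_num=5, embed_num=3000, embed_dim=300)  # assumed label/vocab sizes
tester = ClassTester(test_args)
tester.test(model)                          # forward passes over the test pickle
y_true, y_prob, acc = tester.matrices()     # accuracy over the whole test set
print("test accuracy: {:.2%}".format(acc))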

+ 199  - 0   fastNLP/action/trainer.py

@@ -2,6 +2,10 @@ import _pickle


import numpy as np
import torch
import torch.nn as nn
import os
from time import time
from datetime import timedelta


from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -392,6 +396,201 @@ class POSTrainer(BaseTrainer):
return loss




class ClassTrainer(BaseTrainer):
"""Trainer for classification."""

def __init__(self, train_args):
# super(ClassTrainer, self).__init__(train_args)
self.n_epochs = train_args["epochs"]
self.batch_size = train_args["batch_size"]
self.pickle_path = train_args["pickle_path"]

if "validate" in train_args:
self.validate = train_args["validate"]
else:
self.validate = False
if "learn_rate" in train_args:
self.learn_rate = train_args["learn_rate"]
else:
self.learn_rate = 1e-3
if "momentum" in train_args:
self.momentum = train_args["momentum"]
else:
self.momentum = 0.9
if "use_cuda" in train_args:
self.use_cuda = train_args["use_cuda"]
else:
self.use_cuda = True

self.model = None
self.iterator = None
self.loss_func = None
self.optimizer = None

def train(self, network):
"""General Training Steps
:param network: a model

The method is framework independent.
Work by calling the following methods:
- prepare_input
- mode
- define_optimizer
- data_forward
- get_loss
- grad_backward
- update
Subclasses must implement these methods with a specific framework.
"""
# prepare model and data, transfer model to gpu if available
if torch.cuda.is_available() and self.use_cuda:
self.model = network.cuda()
else:
self.model = network
data_train, data_dev, data_test, embedding = self.prepare_input(
self.pickle_path)

# define tester over dev data
# valid_args = {
# "save_output": True, "validate_in_training": True,
# "save_dev_input": True, "save_loss": True,
# "batch_size": self.batch_size, "pickle_path": self.pickle_path}
# validator = POSTester(valid_args)

# turn on network training mode; define loss and optimizer
self.define_loss()
self.define_optimizer()
self.mode(test=False)

# main training epochs
start = time()
n_samples = len(data_train)
n_batches = n_samples // self.batch_size
n_print = n_batches // 10
for epoch in range(self.n_epochs):
# prepare batch iterator
self.iterator = iter(Batchifier(
RandomSampler(data_train), self.batch_size, drop_last=False))

# training iterations in one epoch
step = 0
for batch_x, batch_y in self.batchify(data_train):
prediction = self.data_forward(network, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()

if step % n_print == 0:
acc = self.get_acc(prediction, batch_y)
end = time()
diff = timedelta(seconds=round(end - start))
print("epoch: {:>3} step: {:>4} loss: {:>4.2}"
" train acc: {:>5.1%} time: {}".format(
epoch, step, loss, acc, diff))

step += 1

# if self.validate:
# if data_dev is None:
# raise RuntimeError("No validation data provided.")
# validator.test(network)
# print("[epoch {}]".format(epoch), end=" ")
# print(validator.show_matrices())

# finish training

def prepare_input(self, data_path):
"""
To do: Load pkl files of train/dev/test and embedding
"""

names = [
"data_train.pkl", "data_dev.pkl",
"data_test.pkl", "embedding.pkl"]

files = []
for name in names:
file_path = os.path.join(data_path, name)
if os.path.exists(file_path):
with open(file_path, 'rb') as f:
data = _pickle.load(f)
else:
data = []
files.append(data)

return tuple(files)

def mode(self, test=False):
"""
Switch the model between training and evaluation mode.
:param test: bool
"""
if test:
self.model.eval()
else:
self.model.train()

def define_loss(self):
"""
Assign an instance of loss function to self.loss_func
E.g. self.loss_func = nn.CrossEntropyLoss()
"""
if self.loss_func is None:
if hasattr(self.model, "loss"):
self.loss_func = self.model.loss
else:
self.loss_func = nn.CrossEntropyLoss()

def define_optimizer(self):
"""
Define the framework-specific optimizer used to update the model.
"""
self.optimizer = torch.optim.SGD(
self.model.parameters(),
lr=self.learn_rate,
momentum=self.momentum)

def data_forward(self, network, x):
"""Forward through network."""
logits = network(x)
return logits

def get_loss(self, predict, truth):
"""Calculate loss."""
return self.loss_func(predict, truth)

def grad_backward(self, loss):
"""Compute gradient backward."""
self.model.zero_grad()
loss.backward()

def update(self):
"""Apply gradient."""
self.optimizer.step()

def batchify(self, data):
"""Batch and pad data."""
for indices in self.iterator:
batch = [data[idx] for idx in indices]
batch_x = [sample[0] for sample in batch]
batch_y = [sample[1] for sample in batch]
batch_x = self.pad(batch_x)

batch_x = torch.tensor(batch_x, dtype=torch.long)
batch_y = torch.tensor(batch_y, dtype=torch.long)
if torch.cuda.is_available() and self.use_cuda:
batch_x = batch_x.cuda()
batch_y = batch_y.cuda()

yield batch_x, batch_y

def get_acc(self, y_logit, y_true):
"""Compute accuracy."""
y_pred = torch.argmax(y_logit, dim=-1)
return int(torch.sum(y_true == y_pred)) / len(y_true)


if __name__ == "__main__":
train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"}
trainer = BaseTrainer(train_args)
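
A minimal training sketch using the new ClassTrainer together with CNNText; it assumes a data_train.pkl produced by ClassPreprocess under ./save/, and the hyper-parameters and sizes below are illustrative assumptions.

from fastNLP.action.trainer import ClassTrainer
from fastNLP.models.cnn_text_classification import CNNText

train_args = {
    "epochs": 5,
    "batch_size": 32,
    "pickle_path": "./save/",   # assumed directory holding data_train.pkl
    "validate": False,
    "learn_rate": 1e-3,
    "momentum": 0.9,
    "use_cuda": False,
}

model = CNNText(class_num=5, embed_num=3000, embed_dim=300)  # assumed label/vocab sizes
trainer = ClassTrainer(train_args)
trainer.train(model)            # SGD with momentum, cross-entropy loss by default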


+ 14  - 9   fastNLP/loader/dataset_loader.py

@@ -29,11 +29,11 @@ class POSDatasetLoader(DatasetLoader):
return lines




class ClassificationDatasetLoader(DatasetLoader):
"""loader for classfication data sets"""
class ClassDatasetLoader(DatasetLoader):
"""Loader for classification data sets"""


def __init__(self, data_name, data_path):
super(ClassificationDatasetLoader, data_name).__init__()
super(ClassDatasetLoader, self).__init__(data_name, data_path)


def load(self):
assert os.path.exists(self.data_path)
@@ -44,16 +44,21 @@ class ClassificationDatasetLoader(DatasetLoader):
@staticmethod
def parse(lines):
"""
:param lines: lines from dataset
:return: list(list(list())): the three level of lists are
Params
lines: lines from dataset
Return
list(list(list())): the three level of lists are
words, sentence, and dataset
"""
dataset = list()
for line in lines:
label = line.split(" ")[0]
words = line.split(" ")[1:]
word = list([w for w in words])
sentence = list([word, label])
line = line.strip().split()
label = line[0]
words = line[1:]
if len(words) <= 1:
continue

sentence = [words, label]
dataset.append(sentence)
return dataset
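
As an illustration of the new parse() behaviour (the sample lines are made up), each input line starts with the label, followed by the tokenized sentence:

from fastNLP.loader.dataset_loader import ClassDatasetLoader

lines = ["pos this movie is great",
         "neg boring and far too long"]
print(ClassDatasetLoader.parse(lines))
# expected output:
# [[['this', 'movie', 'is', 'great'], 'pos'],
#  [['boring', 'and', 'far', 'too', 'long'], 'neg']]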




+ 186  - 1   fastNLP/loader/preprocess.py

@@ -198,4 +198,189 @@ class POSPreprocess(BasePreprocess):
pass
def data_test(self):
pass
class ClassPreprocess(BasePreprocess):
"""
Pre-process the classification datasets.
Params:
pickle_path - directory to save result of pre-processing
Saves:
word2id.pkl
id2word.pkl
class2id.pkl
id2class.pkl
embedding.pkl
data_train.pkl
data_dev.pkl
data_test.pkl
"""
def __init__(self, pickle_path):
# super(ClassPreprocess, self).__init__(data, pickle_path)
self.word_dict = None
self.label_dict = None
self.pickle_path = pickle_path # save directory
def process(self, data, save_name):
"""
Process data.
Params:
data - nested list, data = [sample1, sample2, ...],
sample = [sentence, label], sentence = [word1, word2, ...]
save_name - name of processed data, such as data_train.pkl
Returns:
vocab_size - vocabulary size
n_classes - number of classes
"""
self.build_dict(data)
self.word2id()
vocab_size = self.id2word()
self.class2id()
num_classes = self.id2class()
self.embedding()
self.data_generate(data, save_name)
return vocab_size, num_classes
def build_dict(self, data):
"""Build vocabulary."""
# just read if word2id.pkl and class2id.pkl exists
if self.pickle_exist("word2id.pkl") and \
self.pickle_exist("class2id.pkl"):
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, 'rb') as f:
self.word_dict = _pickle.load(f)
file_name = os.path.join(self.pickle_path, "class2id.pkl")
with open(file_name, 'rb') as f:
self.label_dict = _pickle.load(f)
return
# build vocabulary from scratch if nothing exists
self.word_dict = {
DEFAULT_PADDING_LABEL: 0,
DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2,
DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
self.label_dict = {}
# collect every word and label
for sent, label in data:
if len(sent) <= 1:
continue
if label not in self.label_dict:
index = len(self.label_dict)
self.label_dict[label] = index
for word in sent:
if word not in self.word_dict:
index = len(self.word_dict)
self.word_dict[word] = index
def pickle_exist(self, pickle_name):
"""
Check whether a pickle file exists.
Params
pickle_name: the filename of target pickle file
Return
True if file exists else False
"""
if not os.path.exists(self.pickle_path):
os.makedirs(self.pickle_path)
file_name = os.path.join(self.pickle_path, pickle_name)
if os.path.exists(file_name):
return True
else:
return False
def word2id(self):
"""Save vocabulary of {word:id} mapping format."""
# nothing will be done if word2id.pkl exists
if self.pickle_exist("word2id.pkl"):
return
file_name = os.path.join(self.pickle_path, "word2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.word_dict, f)
def id2word(self):
"""Save vocabulary of {id:word} mapping format."""
# nothing will be done if id2word.pkl exists
if self.pickle_exist("id2word.pkl"):
file_name = os.path.join(self.pickle_path, "id2word.pkl")
with open(file_name, 'rb') as f:
id2word_dict = _pickle.load(f)
return len(id2word_dict)
id2word_dict = {self.word_dict[w]: w for w in self.word_dict}
file_name = os.path.join(self.pickle_path, "id2word.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2word_dict, f)
return len(id2word_dict)
def class2id(self):
"""Save mapping of {class:id}."""
# nothing will be done if class2id.pkl exists
if self.pickle_exist("class2id.pkl"):
return
file_name = os.path.join(self.pickle_path, "class2id.pkl")
with open(file_name, "wb") as f:
_pickle.dump(self.label_dict, f)
def id2class(self):
"""Save mapping of {id:class}."""
# nothing will be done if id2class.pkl exists
if self.pickle_exist("id2class.pkl"):
file_name = os.path.join(self.pickle_path, "id2class.pkl")
with open(file_name, "rb") as f:
id2class_dict = _pickle.load(f)
return len(id2class_dict)
id2class_dict = {self.label_dict[c]: c for c in self.label_dict}
file_name = os.path.join(self.pickle_path, "id2class.pkl")
with open(file_name, "wb") as f:
_pickle.dump(id2class_dict, f)
return len(id2class_dict)
def embedding(self):
"""Save embedding lookup table corresponding to vocabulary."""
# nothing will be done if embedding.pkl exists
if self.pickle_exist("embedding.pkl"):
return
# retrieve vocabulary from pre-trained embedding (not implemented)
def data_generate(self, data_src, save_name):
"""Convert dataset from text to digit."""
# nothing will be done if file exists
save_path = os.path.join(self.pickle_path, save_name)
if os.path.exists(save_path):
return
data = []
# for every sample
for sent, label in data_src:
if len(sent) <= 1:
continue
label_id = self.label_dict[label] # label id
sent_id = [] # sentence ids
for word in sent:
if word in self.word_dict:
sent_id.append(self.word_dict[word])
else:
sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL])
data.append([sent_id, label_id])
# save data
with open(save_path, "wb") as f:
_pickle.dump(data, f)
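
A sketch of how ClassDatasetLoader and ClassPreprocess might be chained; the corpus path and save directory are assumptions, and load() is assumed to return the raw text lines (as POSDatasetLoader does) with parse() converting them.

from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.preprocess import ClassPreprocess

loader = ClassDatasetLoader("train", "./data/train.txt")  # assumed label-first corpus
lines = loader.load()                                     # assumed raw text lines
data = ClassDatasetLoader.parse(lines)                    # [[words, label], ...]

preprocess = ClassPreprocess("./save/")                   # pickles are written here
vocab_size, n_classes = preprocess.process(data, "data_train.pkl")
print(vocab_size, n_classes)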

+ 37  - 0   fastNLP/models/cnn_text_classification.py

@@ -0,0 +1,37 @@
# python: 3.6
# encoding: utf-8

import torch.nn as nn
# import torch.nn.functional as F
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool


class CNNText(BaseModel):
"""
Text classification model by character CNN, the implementation of paper
'Yoon Kim. 2014. Convolution Neural Networks for Sentence
Classification.'
"""

def __init__(self, class_num=9,
kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5],
embed_num=1000, embed_dim=300, pretrained_embed=None,
drop_prob=0.5):
super(CNNText, self).__init__()

# no support for pre-trained embedding currently
self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
self.conv_pool = ConvMaxpool(
in_channels=embed_dim,
out_channels=kernel_nums,
kernel_sizes=kernel_sizes)
self.dropout = nn.Dropout(drop_prob)
self.fc = nn.Linear(sum(kernel_nums), class_num)

def forward(self, x):
x = self.embed(x) # [N,L] -> [N,L,C]
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
return x
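
A quick shape check for CNNText (the batch and sequence sizes are illustrative):

import torch
from fastNLP.models.cnn_text_classification import CNNText

model = CNNText(class_num=9, embed_num=1000, embed_dim=300)
x = torch.randint(0, 1000, (4, 20))   # [N, L] batch of word ids
logits = model(x)
print(logits.shape)                   # expected: torch.Size([4, 9])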

+ 53  - 0   fastNLP/modules/encoder/conv_maxpool.py

@@ -0,0 +1,53 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMaxpool(nn.Module):
"""
Convolution and max-pooling module with multiple kernel sizes.
"""

def __init__(self, in_channels, out_channels, kernel_sizes,
stride=1, padding=0, dilation=1,
groups=1, bias=True, activation='relu'):
super(ConvMaxpool, self).__init__()

# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
if isinstance(kernel_sizes, int):
out_channels = [out_channels]
kernel_sizes = [kernel_sizes]
self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
out_channels=oc,
kernel_size=ks,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
for oc, ks in zip(out_channels, kernel_sizes)])
else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')

# activation function
if activation == 'relu':
self.activation = F.relu
else:
raise Exception(
"Undefined activation function: choose from: relu")

def forward(self, x):
# [N,L,C] -> [N,C,L]
x = torch.transpose(x, 1, 2)
# convolution
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L]]
# max-pooling
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
for i in xs] # [[N, C]]
return torch.cat(xs, dim=-1) # [N,C]
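
And a stand-alone shape check for ConvMaxpool (sizes are illustrative):

import torch
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool

conv_pool = ConvMaxpool(in_channels=300, out_channels=[100, 100, 100],
                        kernel_sizes=[3, 4, 5])
x = torch.randn(4, 20, 300)   # [N, L, C]
out = conv_pool(x)
print(out.shape)              # expected: torch.Size([4, 300]), i.e. sum(out_channels)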
