
* add DataSet.split()

* delete field.py
* remove logger in all codes
* adjust arguments of Trainer
tags/v0.2.0
FengZiYjun yunfan, 6 years ago
commit 80884322c2
5 changed files with 135 additions and 322 deletions:

  1. fastNLP/core/dataset.py      +41    -3
  2. fastNLP/core/field.py         +0   -89
  3. fastNLP/core/tester.py        +2    -2
  4. fastNLP/core/trainer.py      +86  -221
  5. fastNLP/io/config_saver.py    +6    -7

fastNLP/core/dataset.py (+41 -3)

@@ -1,3 +1,5 @@
+import numpy as np
+
 from fastNLP.core.fieldarray import FieldArray


 _READERS = {}
@@ -6,7 +8,7 @@ _READERS = {}
 def construct_dataset(sentences):
     """Construct a data set from a list of sentences.

-    :param sentences: list of str
+    :param sentences: list of list of str
     :return dataset: a DataSet object
     """
     dataset = DataSet()
@@ -18,7 +20,9 @@ def construct_dataset(sentences):


 class DataSet(object):
-    """A DataSet object is a list of Instance objects.
+    """DataSet is the collection of examples.
+    DataSet provides instance-level interface. You can append and access an instance of the DataSet.
+    However, it stores data in a different way: Field-first, Instance-second.

     """

@@ -47,6 +51,11 @@ class DataSet(object):
                               in self.dataset.get_fields().keys()])

     def __init__(self, data=None):
+        """
+
+        :param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field.
+                If it is a list, it must be a list of Instance objects.
+        """
         self.field_arrays = {}
         if data is not None:
             if isinstance(data, dict):
@@ -78,8 +87,14 @@ class DataSet(object):
                 self.append(ins_list)

     def append(self, ins):
-        # no field
+        """Add an instance to the DataSet.
+        If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet.
+
+        :param ins: an Instance object
+
+        """
         if len(self.field_arrays) == 0:
+            # DataSet has no field yet
             for name, field in ins.fields.items():
                 self.field_arrays[name] = FieldArray(name, [field])
         else:
@@ -89,6 +104,15 @@ class DataSet(object):
                 self.field_arrays[name].append(field)

     def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
+        """
+
+        :param name:
+        :param fields:
+        :param padding_val:
+        :param need_tensor:
+        :param is_target:
+        :return:
+        """
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields,
@@ -210,6 +234,20 @@ class DataSet(object):
         else:
             return results

+    def split(self, test_ratio):
+        assert isinstance(test_ratio, float)
+        all_indices = [_ for _ in range(len(self))]
+        np.random.shuffle(all_indices)
+        test_indices = all_indices[:int(test_ratio)]
+        train_indices = all_indices[int(test_ratio):]
+        test_set = DataSet()
+        train_set = DataSet()
+        for idx in test_indices:
+            test_set.append(self[idx])
+        for idx in train_indices:
+            train_set.append(self[idx])
+        return train_set, test_set
+

 if __name__ == '__main__':
     from fastNLP.core.instance import Instance
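
Note on the new DataSet.split() above: `int(test_ratio)` truncates the float, so for any ratio below 1.0 the test set comes out empty and everything lands in the train set (e.g. `int(0.2) == 0`). A minimal corrected sketch, assuming the intent is a fractional split; scaling by `len(self)` is my reading, not part of this commit. It relies on the `import numpy as np` added at the top of this file and on `self[idx]` returning an Instance, as the committed code does:

    def split(self, test_ratio):
        assert isinstance(test_ratio, float) and 0.0 < test_ratio < 1.0
        all_indices = list(range(len(self)))
        np.random.shuffle(all_indices)
        n_test = int(test_ratio * len(self))  # convert the ratio into an example count
        train_set, test_set = DataSet(), DataSet()
        for idx in all_indices[:n_test]:
            test_set.append(self[idx])
        for idx in all_indices[n_test:]:
            train_set.append(self[idx])
        return train_set, test_set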


fastNLP/core/field.py (+0 -89)

@@ -1,89 +0,0 @@
-import torch
-
-
-class Field(object):
-    """A field defines a data type.
-
-    """
-
-    def __init__(self, content, is_target: bool):
-        self.is_target = is_target
-        self.content = content
-
-    def index(self, vocab):
-        """create index field
-        """
-        raise NotImplementedError
-
-    def __len__(self):
-        """number of samples
-        """
-        assert self.content is not None
-        return len(self.content)
-
-    def to_tensor(self, id_list):
-        """convert batch of index to tensor
-        """
-        raise NotImplementedError
-
-    def __repr__(self):
-        return self.content.__repr__()
-
-
-class TextField(Field):
-    def __init__(self, text, is_target):
-        """
-        :param text: list of strings
-        :param is_target: bool
-        """
-        super(TextField, self).__init__(text, is_target)
-
-
-class LabelField(Field):
-    """The Field representing a single label. Can be a string or integer.
-
-    """
-
-    def __init__(self, label, is_target=True):
-        super(LabelField, self).__init__(label, is_target)
-
-
-class SeqLabelField(Field):
-    def __init__(self, label_seq, is_target=True):
-        super(SeqLabelField, self).__init__(label_seq, is_target)
-
-
-class CharTextField(Field):
-    def __init__(self, text, max_word_len, is_target=False):
-        super(CharTextField, self).__init__(is_target)
-        # TODO
-        raise NotImplementedError
-        self.max_word_len = max_word_len
-        self._index = []
-
-    def get_length(self):
-        return len(self.text)
-
-    def contents(self):
-        return self.text.copy()
-
-    def index(self, char_vocab):
-        if len(self._index) == 0:
-            for word in self.text:
-                char_index = [char_vocab[ch] for ch in word]
-                if self.max_word_len >= len(char_index):
-                    char_index += [0] * (self.max_word_len - len(char_index))
-                else:
-                    self._index.clear()
-                    raise RuntimeError("Word {} has more than {} characters. ".format(word, self.max_word_len))
-                self._index.append(char_index)
-        return self._index
-
-    def to_tensor(self, padding_length):
-        """
-
-        :param padding_length: int, the padding length of the word sequence.
-        :return : tensor of shape (padding_length, max_word_len)
-        """
-        pads = [[0] * self.max_word_len] * (padding_length - self.get_length())
-        return torch.LongTensor(self._index + pads)
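
With field.py deleted, nothing wraps values anymore: DataSet.append feeds each instance field straight into a per-column FieldArray (see the dataset.py hunk above). A short sketch of the replacement path; the keyword-argument Instance constructor is an assumption consistent with the `ins.fields.items()` usage in `append`:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    ds = DataSet()
    ds.append(Instance(sentence=["this", "is", "an", "example"], label=1))
    ds.append(Instance(sentence=["another", "one"], label=0))
    # data is stored column-wise: one FieldArray per field name
    assert set(ds.field_arrays.keys()) == {"sentence", "label"}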

fastNLP/core/tester.py (+2 -2)

@@ -5,9 +5,9 @@ import torch
 from fastNLP.core.batch import Batch
 from fastNLP.core.metrics import Evaluator
 from fastNLP.core.sampler import RandomSampler
-from fastNLP.io.logger import create_logger

-logger = create_logger(__name__, "./train_test.log")
+# logger = create_logger(__name__, "./train_test.log")
+


 class Tester(object):


fastNLP/core/trainer.py (+86 -221)

@@ -1,4 +1,3 @@
-import os
 import time
 from datetime import timedelta, datetime

@@ -11,157 +10,76 @@ from fastNLP.core.metrics import Evaluator
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.tester import Tester
-from fastNLP.io.logger import create_logger
-from fastNLP.io.model_saver import ModelSaver
-
-logger = create_logger(__name__, "./train_test.log")
-logger.disabled = True


 class Trainer(object):
-    """Operations of training a model, including data loading, gradient descent, and validation.
+    """Main Training Loop

     """

-    def __init__(self, **kwargs):
-        """
-        :param kwargs: dict of (key, value), or dict-like object. key is str.
-
-        The base trainer requires the following keys:
-            - epochs: int, the number of epochs in training
-            - validate: bool, whether or not to validate on dev set
-            - batch_size: int
-            - pickle_path: str, the path to pickle files for pre-processing
-        """
+    def __init__(self, train_data, model, n_epochs, batch_size, n_print,
+                 dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
+                 optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
+                 evaluator=Evaluator(),
+                 **kwargs):
         super(Trainer, self).__init__()

-        """
-            "default_args" provides default value for important settings.
-            The initialization arguments "kwargs" with the same key (name) will override the default value.
-            "kwargs" must have the same type as "default_args" on corresponding keys.
-            Otherwise, error will raise.
-        """
-        default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
-                        "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
-                        "valid_step": 500, "eval_sort_key": 'acc',
-                        "loss": Loss(None),  # used to pass type check
-                        "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
-                        "eval_batch_size": 64,
-                        "evaluator": Evaluator(),
-                        }
-        """
-            "required_args" is the collection of arguments that users must pass to Trainer explicitly.
-            This is used to warn users of essential settings in the training.
-            Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
-        """
-        required_args = {}
-
-        for req_key in required_args:
-            if req_key not in kwargs:
-                logger.error("Trainer lacks argument {}".format(req_key))
-                raise ValueError("Trainer lacks argument {}".format(req_key))
-
-        for key in default_args:
-            if key in kwargs:
-                if isinstance(kwargs[key], type(default_args[key])):
-                    default_args[key] = kwargs[key]
-                else:
-                    msg = "Argument %s type mismatch: expected %s while get %s" % (
-                        key, type(default_args[key]), type(kwargs[key]))
-                    logger.error(msg)
-                    raise ValueError(msg)
-            else:
-                # Trainer doesn't care about extra arguments
-                pass
-        print("Training Args {}".format(default_args))
-        logger.info("Training Args {}".format(default_args))
-
-        self.n_epochs = int(default_args["epochs"])
-        self.batch_size = int(default_args["batch_size"])
-        self.eval_batch_size = int(default_args['eval_batch_size'])
-        self.pickle_path = default_args["pickle_path"]
-        self.validate = default_args["validate"]
-        self.save_best_dev = default_args["save_best_dev"]
-        self.use_cuda = default_args["use_cuda"]
-        self.model_name = default_args["model_name"]
-        self.print_every_step = int(default_args["print_every_step"])
-        self.valid_step = int(default_args["valid_step"])
-        if self.validate is not None:
-            assert self.valid_step > 0
-
-        self._model = None
-        self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
-        self._optimizer = None
-        self._optimizer_proto = default_args["optimizer"]
-        self._evaluator = default_args["evaluator"]
-        self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
+        self.train_data = train_data
+        self.dev_data = dev_data  # If None, No validation.
+        self.model = model
+        self.n_epochs = int(n_epochs)
+        self.batch_size = int(batch_size)
+        self.use_cuda = bool(use_cuda)
+        self.save_path = str(save_path)
+        self.n_print = int(n_print)
+
+        self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
+        self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
+        self.evaluator = evaluator
+
+        if self.dev_data is not None:
+            valid_args = {"batch_size": self.batch_size, "save_path": self.save_path,
+                          "use_cuda": self.use_cuda, "evaluator": self.evaluator}
+            self.tester = Tester(**valid_args)
+
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+
+        self._summary_writer = SummaryWriter(self.save_path + 'tensorboard_logs')
         self._graph_summaried = False
-        self._best_accuracy = 0.0
-        self.eval_sort_key = default_args['eval_sort_key']
-        self.validator = None
-        self.epoch = 0
         self.step = 0
         self.start_time = None  # start timestamp

-    def train(self, network, train_data, dev_data=None):
-        """General Training Procedure
+        print(self.__dict__)

-        :param network: a model
-        :param train_data: a DataSet instance, the training data
-        :param dev_data: a DataSet instance, the validation data (optional)
+    def train(self):
+        """Start Training.
+
+        :return:
         """
-        # transfer model to gpu if available
         if torch.cuda.is_available() and self.use_cuda:
-            self._model = network.cuda()
-            # self._model is used to access model-specific loss
-        else:
-            self._model = network
-
-        print(self._model)
-
-        # define Tester over dev data
-        self.dev_data = None
-        if self.validate:
-            default_valid_args = {"batch_size": self.eval_batch_size, "pickle_path": self.pickle_path,
-                                  "use_cuda": self.use_cuda, "evaluator": self._evaluator}
-            if self.validator is None:
-                self.validator = self._create_validator(default_valid_args)
-            logger.info("validator defined as {}".format(str(self.validator)))
-            self.dev_data = dev_data
-
-        # optimizer and loss
-        self.define_optimizer()
-        logger.info("optimizer defined as {}".format(str(self._optimizer)))
-        self.define_loss()
-        logger.info("loss function defined as {}".format(str(self._loss_func)))
-
-        # turn on network training mode
-        self.mode(network, is_test=False)
-
-        # main training procedure
+            self.model = self.model.cuda()
+
+        self.mode(self.model, is_test=False)
+
         start = time.time()
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
         print("training epochs started " + self.start_time)
-        logger.info("training epochs started " + self.start_time)
-        self.epoch, self.step = 1, 0
-        while self.epoch <= self.n_epochs:
-            logger.info("training epoch {}".format(self.epoch))
-
-            # prepare mini-batch iterator
-            data_iterator = Batch(train_data, batch_size=self.batch_size,
-                                  sampler=BucketSampler(10, self.batch_size, "word_seq_origin_len"),
+
+        epoch = 1
+        while epoch <= self.n_epochs:
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
                                   use_cuda=self.use_cuda)
-            logger.info("prepared data iterator")

-            # one forward and backward pass
-            self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, dev_data=dev_data)
+            self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start, self.n_print)

-            # validation
-            if self.validate:
-                self.valid_model()
-            self.save_model(self._model, 'training_model_' + self.start_time)
-            self.epoch += 1
+            if self.dev_data:
+                self.do_validation()
+            self.save_model(self.model, 'training_model_' + self.start_time)
+            epoch += 1

-    def _train_step(self, data_iterator, network, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, n_print, **kwargs):
         """Training process in one epoch.

         kwargs should contain:
@@ -170,7 +88,7 @@ class Trainer(object):
             - epoch: int,
         """
         for batch_x, batch_y in data_iterator:
-            prediction = self.data_forward(network, batch_x)
+            prediction = self.data_forward(model, batch_x)

             # TODO: refactor self.get_loss
             loss = prediction["loss"] if "loss" in prediction else self.get_loss(prediction, batch_y)
@@ -179,35 +97,25 @@ class Trainer(object):
             self.grad_backward(loss)
             self.update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
-            for name, param in self._model.named_parameters():
+            for name, param in self.model.named_parameters():
                 if param.requires_grad:
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
-                    # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
-                    # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-            if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
+                    self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
+                    self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
+            if n_print > 0 and self.step % n_print == 0:
                 end = time.time()
                 diff = timedelta(seconds=round(end - kwargs["start"]))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
-                    self.epoch, self.step, loss.data, diff)
+                    epoch, self.step, loss.data, diff)
                 print(print_output)
-                logger.info(print_output)
-            if self.validate and self.valid_step > 0 and self.step > 0 and self.step % self.valid_step == 0:
-                self.valid_model()

             self.step += 1

-    def valid_model(self):
-        if self.dev_data is None:
-            raise RuntimeError(
-                "self.validate is True in trainer, but dev_data is None. Please provide the validation data.")
-        logger.info("validation started")
-        res = self.validator.test(self._model, self.dev_data)
+    def do_validation(self):
+        res = self.tester.test(self.model, self.dev_data)
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
         if self.save_best_dev and self.best_eval_result(res):
-            logger.info('save best result! {}'.format(res))
-            print('save best result! {}'.format(res))
-            self.save_model(self._model, 'best_model_' + self.start_time)
-        return res
+            self.save_model(self.model, 'best_model_' + self.start_time)

     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -221,23 +129,11 @@ class Trainer(object):
         else:
             model.train()

-    def define_optimizer(self, optim=None):
-        """Define framework-specific optimizer specified by the models.
-
-        """
-        if optim is not None:
-            # optimizer constructed by user
-            self._optimizer = optim
-        elif self._optimizer is None:
-            # optimizer constructed by proto
-            self._optimizer = self._optimizer_proto.construct_from_pytorch(self._model.parameters())
-        return self._optimizer
-
     def update(self):
         """Perform weight update on a model.

         """
-        self._optimizer.step()
+        self.optimizer.step()

     def data_forward(self, network, x):
         y = network(**x)
@@ -253,7 +149,7 @@ class Trainer(object):

         For PyTorch, just do "loss.backward()"
         """
-        self._model.zero_grad()
+        self.model.zero_grad()
         loss.backward()

     def get_loss(self, predict, truth):
@@ -264,68 +160,37 @@ class Trainer(object):
         :return: a scalar
         """
         if isinstance(predict, dict) and isinstance(truth, dict):
-            return self._loss_func(**predict, **truth)
+            return self.loss_func(**predict, **truth)
         if len(truth) > 1:
             raise NotImplementedError("Not ready to handle multi-labels.")
         truth = list(truth.values())[0] if len(truth) > 0 else None
-        return self._loss_func(predict, truth)
-
-    def define_loss(self):
-        """Define a loss for the trainer.
-
-        If the model defines a loss, use model's loss.
-        Otherwise, Trainer must has a loss argument, use it as loss.
-        These two losses cannot be defined at the same time.
-        Trainer does not handle loss definition or choose default losses.
-        """
-        # if hasattr(self._model, "loss") and self._loss_func is not None:
-        #     raise ValueError("Both the model and Trainer define loss. Please take out your loss.")
-
-        if hasattr(self._model, "loss"):
-            self._loss_func = self._model.loss
-            logger.info("The model has a loss function, use it.")
-        else:
-            if self._loss_func is None:
-                raise ValueError("Please specify a loss function.")
-            logger.info("The model didn't define loss, use Trainer's loss.")
+        return self.loss_func(predict, truth)
+
+    def save_model(self, model, model_name, only_param=False):
+        if only_param:
+            torch.save(model.state_dict(), model_name)
+        else:
+            torch.save(model, model_name)

     def best_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.

-        :param validator: a Tester instance
         :return: bool, True means current results on dev set is the best.
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics

         if isinstance(metrics, dict):
             if len(metrics) == 1:
                 accuracy = list(metrics.values())[0]
             else:
                 accuracy = metrics[self.eval_sort_key]
         else:
             accuracy = metrics

         if accuracy > self._best_accuracy:
             self._best_accuracy = accuracy
             return True
         else:
             return False
-
-    def save_model(self, network, model_name):
-        """Save this model with such a name.
-        This method may be called multiple times by Trainer to overwritten a better model.
-
-        :param network: the PyTorch model
-        :param model_name: str
-        """
-        if model_name[-4:] != ".pkl":
-            model_name += ".pkl"
-        ModelSaver(os.path.join(self.pickle_path, model_name)).save_pytorch(network)
-
-    def _create_validator(self, valid_args):
-        return Tester(**valid_args)
-
-    def set_validator(self, validor):
-        self.validator = validor
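
Two review notes on the rewritten Trainer, plus a usage sketch. First, as rendered above, `_train_epoch` now receives `start` and `n_print` positionally, yet the unchanged line `diff = timedelta(seconds=round(end - kwargs["start"]))` still reads `kwargs["start"]`; since `train()` passes no keyword arguments, the first report step would raise KeyError (reading the `start` parameter directly would match the new signature). Second, `do_validation` and `best_eval_result` still consult `self.save_best_dev`, `self.eval_sort_key`, and `self._best_accuracy`, none of which the new `__init__` sets, so they must arrive through `**kwargs` or best-model saving will fail. Note also that the defaults `loss=Loss(None)`, `optimizer=Optimizer("Adam", ...)`, and `evaluator=Evaluator()` are evaluated once at class-definition time and shared across Trainer instances, the usual Python default-argument caveat.

A minimal usage sketch of the new constructor; `model`, `train_set`, and `dev_set` are placeholders, not from this commit, and the Trainer/Optimizer import paths follow the file paths visible in this diff:

    from fastNLP.core.optimizer import Optimizer
    from fastNLP.core.trainer import Trainer

    # model: a PyTorch module assumed to expose .loss (or return a "loss" key
    # from forward); train_set / dev_set: DataSet objects as built in dataset.py.
    trainer = Trainer(train_data=train_set,
                      model=model,
                      n_epochs=3,
                      batch_size=32,
                      n_print=100,          # report training loss every 100 steps
                      dev_data=dev_set,     # a Tester is constructed when this is not None
                      use_cuda=False,
                      save_path="./save/",  # trailing slash: SummaryWriter gets save_path + 'tensorboard_logs'
                      optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
                      save_best_dev=False)  # forwarded via **kwargs; do_validation reads it
    trainer.train()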

fastNLP/io/config_saver.py (+6 -7)

@@ -1,7 +1,6 @@
 import os

 from fastNLP.io.config_loader import ConfigSection, ConfigLoader
-from fastNLP.io.logger import create_logger


 class ConfigSaver(object):
@@ -61,8 +60,8 @@ class ConfigSaver(object):
                 continue

             if '=' not in line:
-                log = create_logger(__name__, './config_saver.log')
-                log.error("can NOT load config file [%s]" % self.file_path)
+                # log = create_logger(__name__, './config_saver.log')
+                # log.error("can NOT load config file [%s]" % self.file_path)
                 raise RuntimeError("can NOT load config file {}".__format__(self.file_path))

             key = line.split('=', maxsplit=1)[0].strip()
@@ -123,10 +122,10 @@ class ConfigSaver(object):
                     change_file = True
                     break
                 if section_file[k] != section[k]:
-                    logger = create_logger(__name__, "./config_loader.log")
-                    logger.warning("section [%s] in config file [%s] has been changed" % (
-                        section_name, self.file_path
-                    ))
+                    # logger = create_logger(__name__, "./config_loader.log")
+                    # logger.warning("section [%s] in config file [%s] has been changed" % (
+                    #     section_name, self.file_path
+                    #))
                     change_file = True
                     break
             if not change_file:
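
One note on a context line left untouched by this commit: `"can NOT load config file {}".__format__(self.file_path)` does not substitute the `{}` placeholder. `str.__format__` treats its argument as a format specification for the whole string, and a spec like a file path raises ValueError rather than producing the intended message; `str.format` is presumably what was meant:

    path = "train.cfg"
    "can NOT load config file {}".format(path)        # 'can NOT load config file train.cfg'
    # "can NOT load config file {}".__format__(path)  # raises ValueError: invalid format specifier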

