
* Add callbacks: EarlyStopCallback

* Change the assert in dataset.py to raise an error
* Add a try-except to the trainer to catch EarlyStopError
* Optimize the trainer code
* Add tests for callbacks
tags/v0.3.1^2
FengZiYjun 6 years ago
parent commit d80d944e40
5 changed files with 119 additions and 38 deletions
  1. +36 -4   fastNLP/core/callback.py
  2. +4 -1    fastNLP/core/dataset.py
  3. +14 -15  fastNLP/core/trainer.py
  4. +13 -0   test/core/test_batch.py
  5. +52 -18  test/core/test_callbacks.py

+36 -4  fastNLP/core/callback.py

@@ -69,16 +69,16 @@ class Callback(object):
         """
         pass
 
-    def on_exception(self, exception, model, indices):
+    def on_exception(self, exception, model):
         """
         This method is triggered when an exception occurs during training.
         :param exception: an Exception of some type, e.g. KeyboardInterrupt
         :param model: the model passed to the Trainer
-        :param indices: index of the current batch
         :return:
         """
         pass
 
 
 def transfer(func):
     """Decorator that forwards calls on the CallbackManager to each Callback subclass.


@@ -206,10 +206,10 @@ class EchoCallback(Callback):
     def after_train(self, model):
         print("after_train")
 
 
 class GradientClipCallback(Callback):
     def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
-        """
-        Before each backward pass, clip the parameters' gradients to a given range.
+        """Before each backward pass, clip the parameters' gradients to a given range.
 
         :param parameters: None, torch.Tensor, or List[torch.Tensor], usually obtained via model.parameters().
             If None, all parameters of the model passed to the Trainer are clipped by default.
@@ -235,6 +235,38 @@ class GradientClipCallback(Callback):
             self.clip_fun(model.parameters(), self.clip_value)
 
 
+class EarlyStopError(BaseException):
+    def __init__(self, msg):
+        super(EarlyStopError, self).__init__(msg)
+
+
+class EarlyStopCallback(Callback):
+    def __init__(self, patience):
+        """
+
+        :param int patience: number of epochs to wait before stopping
+        """
+        super(EarlyStopCallback, self).__init__()
+        self.trainer = None  # override by CallbackManager
+        self.patience = patience
+        self.wait = 0
+        self.epoch = 0
+
+    def after_valid(self, eval_result, metric_key, optimizer):
+        self.epoch += 1
+        if not self.trainer._better_eval_result(eval_result):
+            # current result is getting worse
+            if self.wait == self.patience:
+                raise EarlyStopError("Early stopping raised.")
+            else:
+                self.wait += 1
+        else:
+            self.wait = 0
+
+    def on_exception(self, exception, model):
+        if isinstance(exception, EarlyStopError):
+            print("Early Stopping triggered in epoch {}!".format(self.epoch))
+
+
 
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])

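As a side note (not part of this commit, and independent of the fastNLP API), here is a minimal standalone sketch of the patience logic that EarlyStopCallback.after_valid implements: the wait counter resets whenever the dev result improves, and once the counter has reached `patience`, the next non-improving validation raises the stop signal. The sketch assumes "higher score is better"; in fastNLP the comparison is delegated to Trainer._better_eval_result.

class EarlyStop(Exception):
    """Raised when the dev score has stopped improving."""


def check_early_stop(dev_scores, patience):
    # Walk a sequence of dev scores (higher is better -- an assumption of this sketch)
    # and raise EarlyStop after patience + 1 consecutive non-improving validations,
    # mirroring EarlyStopCallback's wait/patience bookkeeping.
    best, wait = float("-inf"), 0
    for epoch, score in enumerate(dev_scores, start=1):
        if score > best:
            best, wait = score, 0
        else:
            if wait == patience:
                raise EarlyStop("no improvement since epoch {}, stopping at epoch {}".format(epoch - wait - 1, epoch))
            wait += 1


if __name__ == "__main__":
    try:
        # improves until epoch 3, then stagnates; with patience=2 this stops at epoch 6
        check_early_stop([0.5, 0.6, 0.7, 0.7, 0.69, 0.70, 0.68], patience=2)
    except EarlyStop as e:
        print(e)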

+4 -1  fastNLP/core/dataset.py

@@ -146,7 +146,10 @@ class DataSet(object):
             for name, field in ins.fields.items():
                 self.field_arrays[name] = FieldArray(name, [field])
         else:
-            assert len(self.field_arrays) == len(ins.fields)
+            if len(self.field_arrays) != len(ins.fields):
+                raise ValueError(
+                    "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
+                    .format(len(self.field_arrays), len(ins.fields)))
             for name, field in ins.fields.items():
                 assert name in self.field_arrays
                 self.field_arrays[name].append(field)

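A minimal sketch of the new behavior (assuming, as the surrounding context suggests, that the hunk above lives in DataSet.append): appending an Instance whose field count does not match the DataSet now raises a ValueError with a descriptive message instead of failing a bare assert.

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet()
ds.append(Instance(x=[1, 2], y=1))   # the first Instance defines the fields: x and y
try:
    ds.append(Instance(x=[3, 4]))    # only one field -> mismatch with the existing two
except ValueError as e:
    # "DataSet object has 2 fields, but attempt to append an Instance object with 1 fields."
    print(e)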

+14 -15  fastNLP/core/trainer.py

@@ -181,7 +181,6 @@ class Trainer(object):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = self.model.cuda()
         self._model_device = self.model.parameters().__next__().device
-
         self._mode(self.model, is_test=False)
 
         self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
@@ -200,9 +199,12 @@ class Trainer(object):
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
             self._summary_writer = SummaryWriter(path)
 
-        self.callback_manager.before_train()
-        self._train()
-        self.callback_manager.after_train(self.model)
+        try:
+            self.callback_manager.before_train()
+            self._train()
+            self.callback_manager.after_train(self.model)
+        except BaseException as e:
+            self.callback_manager.on_exception(e, self.model)
 
         if self.dev_data is not None:
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -231,10 +233,11 @@ class Trainer(object):
             inner_tqdm = tqdm
         self.step = 0
         start = time.time()
-        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
-        total_steps = data_iterator.num_batches * self.n_epochs
+        total_steps = (len(self.train_data) // self.batch_size + int(
+            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
@@ -291,17 +294,13 @@ class Trainer(object):
                                self.tester._format_eval_results(eval_res)
                     pbar.write(eval_str)
 
-            # if self.validate_every < 0 and self.dev_data:
-            #     eval_res = self._do_validation(epoch=epoch, step=self.step)
-            #     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
-            #         self.tester._format_eval_results(eval_res)
-            #     pbar.write(eval_str)
-            if epoch != self.n_epochs:
-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
-                                      as_numpy=False)
+                # ================= mini-batch end ==================== #
+
             # lr decay; early stopping
             self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
+        # =============== epochs end =================== #
         pbar.close()
+        # ============ tqdm end ============== #
 
     def _do_validation(self, epoch, step):
         res = self.tester.test()
@@ -314,7 +313,7 @@
                 self._save_model(self.model,
                                  "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
             else:
-                self._best_model_states = {name:param.cpu().clone() for name, param in self.model.named_parameters()}
+                self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()}
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step

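The try-except added above routes any exception raised during training (including EarlyStopError from the new callback) to every callback's on_exception(exception, model); in the lines shown it catches BaseException and hands it to the callbacks rather than re-raising, so train() still reaches the best-dev-performance report. A hypothetical callback built on this hook (not part of the commit) could look like the following sketch.

import torch

from fastNLP.core.callback import Callback


class SaveOnInterrupt(Callback):
    """Hypothetical example: persist the model if training is interrupted with Ctrl-C."""

    def on_exception(self, exception, model):
        if isinstance(exception, KeyboardInterrupt):
            torch.save(model.state_dict(), "interrupted_model.pkl")
            print("KeyboardInterrupt caught, model state saved to interrupted_model.pkl")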

+13 -0  test/core/test_batch.py

@@ -6,6 +6,7 @@ import torch
 from fastNLP.core.batch import Batch
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.dataset import construct_dataset
+from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler




@@ -76,3 +77,15 @@ class TestCase1(unittest.TestCase):
             self.assertEqual(tuple(x["x"].shape), (4, 4))
             self.assertTrue(isinstance(y["y"], torch.Tensor))
             self.assertEqual(tuple(y["y"].shape), (4, 4))
+
+    def test_list_of_list_to_tensor(self):
+        ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] +
+                     [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
+        ds.set_input("x")
+        ds.set_target("y")
+        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        for x, y in iter:
+            self.assertTrue(isinstance(x["x"], torch.Tensor))
+            self.assertEqual(tuple(x["x"].shape), (4, 4))
+            self.assertTrue(isinstance(y["y"], torch.Tensor))
+            self.assertEqual(tuple(y["y"].shape), (4, 4))

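The new test mixes 2-element and 4-element lists yet expects a (4, 4) tensor for a batch of four instances, because batching pads every sequence to the length of the longest one in the batch. A standalone sketch of that padding idea (not fastNLP's actual implementation, and assuming a pad value of 0):

import numpy as np

def pad_to_longest(batch, pad_val=0):
    # Pad every list in the batch to the length of the longest list,
    # which is why mixed lengths 2 and 4 collate into a (4, 4) array above.
    max_len = max(len(seq) for seq in batch)
    return np.array([seq + [pad_val] * (max_len - len(seq)) for seq in batch])

print(pad_to_longest([[1, 2], [1, 2], [1, 2, 3, 4], [1, 2, 3, 4]]).shape)  # (4, 4)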
+52 -18  test/core/test_callbacks.py

@@ -2,39 +2,43 @@ import unittest
 
 import numpy as np
 
-from fastNLP.core.callback import EchoCallback
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
+from fastNLP.core.metrics import AccuracyMetric
 from fastNLP.core.optimizer import SGD
 from fastNLP.core.trainer import Trainer
 from fastNLP.models.base_model import NaiveClassifier
 
 
-class TestCallback(unittest.TestCase):
-    def test_case(self):
-        def prepare_fake_dataset():
-            mean = np.array([-3, -3])
-            cov = np.array([[1, 0], [0, 1]])
-            class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+def prepare_env():
+    def prepare_fake_dataset():
+        mean = np.array([-3, -3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
 
-            mean = np.array([3, 3])
-            cov = np.array([[1, 0], [0, 1]])
-            class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+        mean = np.array([3, 3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
 
-            data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
-                               [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
-            return data_set
+        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+        return data_set
 
-        data_set = prepare_fake_dataset()
-        data_set.set_input("x")
-        data_set.set_target("y")
+    data_set = prepare_fake_dataset()
+    data_set.set_input("x")
+    data_set.set_target("y")
+    model = NaiveClassifier(2, 1)
+    return data_set, model
 
-        model = NaiveClassifier(2, 1)
 
+class TestCallback(unittest.TestCase):
+    def test_echo_callback(self):
+        data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=1,
+                          n_epochs=2,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),
@@ -42,3 +46,33 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[EchoCallback()])
         trainer.train()
+
+    def test_gradient_clip(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=30,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
+        trainer.train()
+
+    def test_early_stop(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=50,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.01),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[EarlyStopCallback(5)])
+        trainer.train()
