* Change the assert in dataset.py to a raised error
* Add a try-except to the trainer to catch EarlyStopError
* Clean up the trainer code
* Add tests for callbacks
@@ -69,16 +69,16 @@ class Callback(object):
         """
         pass

-    def on_exception(self, exception, model, indices):
+    def on_exception(self, exception, model):
         """
         This method is triggered when an exception occurs during training.

         :param exception: an Exception of some type, e.g. KeyboardInterrupt
         :param model: the model passed to the Trainer
-        :param indices: the indices of the current batch
         :return:
         """
         pass


 def transfer(func):
     """A decorator that forwards calls on the CallbackManager to each Callback subclass.
@@ -206,10 +206,10 @@ class EchoCallback(Callback):
     def after_train(self, model):
         print("after_train")


 class GradientClipCallback(Callback):
     def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
-        """
-        Clip parameter gradients to a given range before each backward pass.
+        """Clip parameter gradients to a given range before each backward pass.

         :param parameters: None, torch.Tensor, or List[torch.Tensor], usually obtained via model.parameters().
             If None, all parameters of the Trainer's model are clipped by default.
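(The two `clip_type` options presumably dispatch to PyTorch's built-in clipping utilities; a runnable sketch with a toy module:)

    import torch
    from torch import nn

    model = nn.Linear(2, 1)
    model(torch.randn(4, 2)).sum().backward()
    # clip_type='norm': rescale gradients so their combined norm is at most clip_value
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    # clip_type='value': clamp each gradient element into [-clip_value, clip_value]
    nn.utils.clip_grad_value_(model.parameters(), clip_value=1)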
@@ -235,6 +235,38 @@ class GradientClipCallback(Callback):
         self.clip_fun(model.parameters(), self.clip_value)


+class EarlyStopError(BaseException):
+    def __init__(self, msg):
+        super(EarlyStopError, self).__init__(msg)
+
+
+class EarlyStopCallback(Callback):
+    def __init__(self, patience):
+        """
+        :param int patience: number of epochs to wait before stopping
+        """
+        super(EarlyStopCallback, self).__init__()
+        self.trainer = None  # overridden by CallbackManager
+        self.patience = patience
+        self.wait = 0
+        self.epoch = 0
+
+    def after_valid(self, eval_result, metric_key, optimizer):
+        self.epoch += 1
+        if not self.trainer._better_eval_result(eval_result):
+            # current result is getting worse
+            if self.wait == self.patience:
+                raise EarlyStopError("Early stopping raised.")
+            else:
+                self.wait += 1
+        else:
+            self.wait = 0
+
+    def on_exception(self, exception, model):
+        if isinstance(exception, EarlyStopError):
+            print("Early Stopping triggered in epoch {}!".format(self.epoch))
+

 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
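(Intended flow: once `patience` consecutive validations fail to improve, `after_valid` raises `EarlyStopError`; the new try-except in `Trainer.train` below catches it and dispatches `on_exception`, which lets the callback report the stop. A minimal usage sketch, where `train_data`, `dev_data`, and `model` are placeholders and other required Trainer arguments are omitted:)

    trainer = Trainer(train_data, model,
                      dev_data=dev_data,
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      callbacks=[EarlyStopCallback(patience=5)])
    trainer.train()  # returns once dev results stop improving for 5 validations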
@@ -146,7 +146,10 @@ class DataSet(object):
             for name, field in ins.fields.items():
                 self.field_arrays[name] = FieldArray(name, [field])
         else:
-            assert len(self.field_arrays) == len(ins.fields)
+            if len(self.field_arrays) != len(ins.fields):
+                raise ValueError(
+                    "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
+                    .format(len(self.field_arrays), len(ins.fields)))
             for name, field in ins.fields.items():
                 assert name in self.field_arrays
                 self.field_arrays[name].append(field)
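(With this change a field-count mismatch fails with a descriptive ValueError instead of a bare AssertionError; for example, with illustrative field names:)

    ds = DataSet([Instance(x=1, y=2)])
    ds.append(Instance(x=3))
    # ValueError: DataSet object has 2 fields, but attempt to append an Instance object with 1 fields.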
@@ -181,7 +181,6 @@ class Trainer(object):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = self.model.cuda()
         self._model_device = self.model.parameters().__next__().device
-        self._mode(self.model, is_test=False)

         self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
@@ -200,9 +199,12 @@ class Trainer(object):
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
             self._summary_writer = SummaryWriter(path)

-        self.callback_manager.before_train()
-        self._train()
-        self.callback_manager.after_train(self.model)
+        try:
+            self.callback_manager.before_train()
+            self._train()
+            self.callback_manager.after_train(self.model)
+        except BaseException as e:
+            self.callback_manager.on_exception(e, self.model)

         if self.dev_data is not None:
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -231,10 +233,11 @@ class Trainer(object):
             inner_tqdm = tqdm
         self.step = 0
         start = time.time()
-        data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
-        total_steps = data_iterator.num_batches * self.n_epochs
+        total_steps = (len(self.train_data) // self.batch_size + int(
+            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
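(The new total_steps expression is an integer ceiling division, so it no longer depends on a Batch object existing before the progress bar is created; an equivalent check:)

    import math

    n_examples, batch_size, n_epochs = 1030, 32, 3
    by_diff = (n_examples // batch_size + int(n_examples % batch_size != 0)) * n_epochs
    by_ceil = math.ceil(n_examples / batch_size) * n_epochs
    assert by_diff == by_ceil == 99  # 33 batches per epoch, 3 epochs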
@@ -291,17 +294,13 @@ class Trainer(object):
                                self.tester._format_eval_results(eval_res)
                     pbar.write(eval_str)
-            # if self.validate_every < 0 and self.dev_data:
-            #     eval_res = self._do_validation(epoch=epoch, step=self.step)
-            #     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
-            #         self.tester._format_eval_results(eval_res)
-            #     pbar.write(eval_str)
                 if epoch != self.n_epochs:
                     data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                           as_numpy=False)
                 # ================= mini-batch end ==================== #

                 # lr decay; early stopping
                 self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
             # =============== epochs end =================== #
             pbar.close()
         # ============ tqdm end ============== #

     def _do_validation(self, epoch, step):
         res = self.tester.test()
@@ -314,7 +313,7 @@ class Trainer(object):
                 self._save_model(self.model,
                                  "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
             else:
-                self._best_model_states = {name:param.cpu().clone() for name, param in self.model.named_parameters()}
+                self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()}
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
@@ -6,6 +6,7 @@ import torch

 from fastNLP.core.batch import Batch
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.dataset import construct_dataset
+from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
@@ -76,3 +77,15 @@ class TestCase1(unittest.TestCase):
             self.assertEqual(tuple(x["x"].shape), (4, 4))
             self.assertTrue(isinstance(y["y"], torch.Tensor))
             self.assertEqual(tuple(y["y"].shape), (4, 4))
+
+    def test_list_of_list_to_tensor(self):
+        ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] +
+                     [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
+        ds.set_input("x")
+        ds.set_target("y")
+        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        for x, y in iter:
+            self.assertTrue(isinstance(x["x"], torch.Tensor))
+            self.assertEqual(tuple(x["x"].shape), (4, 4))
+            self.assertTrue(isinstance(y["y"], torch.Tensor))
+            self.assertEqual(tuple(y["y"].shape), (4, 4))
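(The (4, 4) shape assertions assume the Batch pads the two shorter x lists up to the longest list in the batch, presumably with zeros, i.e. roughly:)

    # x["x"] for the single batch of 4 instances:
    # [[1, 2, 0, 0],
    #  [1, 2, 0, 0],
    #  [1, 2, 3, 4],
    #  [1, 2, 3, 4]]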
@@ -2,39 +2,43 @@ import unittest

 import numpy as np

-from fastNLP.core.callback import EchoCallback
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
 from fastNLP.core.metrics import AccuracyMetric
 from fastNLP.core.optimizer import SGD
 from fastNLP.core.trainer import Trainer
 from fastNLP.models.base_model import NaiveClassifier


-class TestCallback(unittest.TestCase):
-    def test_case(self):
-        def prepare_fake_dataset():
-            mean = np.array([-3, -3])
-            cov = np.array([[1, 0], [0, 1]])
-            class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+def prepare_env():
+    def prepare_fake_dataset():
+        mean = np.array([-3, -3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

-            mean = np.array([3, 3])
-            cov = np.array([[1, 0], [0, 1]])
-            class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+        mean = np.array([3, 3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

-            data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
-                               [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
-            return data_set
+        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+        return data_set

-        data_set = prepare_fake_dataset()
-        data_set.set_input("x")
-        data_set.set_target("y")
+    data_set = prepare_fake_dataset()
+    data_set.set_input("x")
+    data_set.set_target("y")
+    model = NaiveClassifier(2, 1)
+    return data_set, model

-        model = NaiveClassifier(2, 1)

+class TestCallback(unittest.TestCase):
+    def test_echo_callback(self):
+        data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=1,
+                          n_epochs=2,
                           batch_size=32,
                           print_every=50,
                           optimizer=SGD(lr=0.1),
@@ -42,3 +46,33 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[EchoCallback()])
         trainer.train()
+
+    def test_gradient_clip(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=30,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
+        trainer.train()
+
+    def test_early_stop(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=50,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.01),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[EarlyStopCallback(5)])
+        trainer.train()