update callbacks in Trainer

commit a7274c667c (tags/v0.3.0^2)
FengZiYjun, 5 years ago

3 changed files with 73 additions and 36 deletions
  1. fastNLP/core/batch.py     +6   -1
  2. fastNLP/core/callback.py  +36  -21
  3. fastNLP/core/trainer.py   +31  -14

fastNLP/core/batch.py (+6, -1)

@@ -26,7 +26,8 @@ class Batch(object):
         self.as_numpy = as_numpy
         self.idx_list = None
         self.curidx = 0
-        self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)
+        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
+        self.cur_batch_indices = None

     def __iter__(self):
         self.idx_list = self.sampler(self.dataset)
@@ -42,6 +43,7 @@ class Batch(object):
         batch_x, batch_y = {}, {}

         indices = self.idx_list[self.curidx:endidx]
+        self.cur_batch_indices = indices

         for field_name, field in self.dataset.get_all_fields().items():
             if field.is_target or field.is_input:
@@ -60,6 +62,9 @@ class Batch(object):
     def __len__(self):
         return self.num_batches

+    def get_batch_indices(self):
+        return self.cur_batch_indices
+

 def to_tensor(batch, dtype):
     if dtype in (int, np.int8, np.int16, np.int32, np.int64):
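
With `cur_batch_indices` recorded on every iteration, a `Batch` can now report which dataset rows produced the current mini-batch. A minimal usage sketch (the surrounding loop is illustrative, not part of this commit):

    # `data_iterator` is a fastNLP Batch; get_batch_indices() returns the
    # dataset indices of the most recently yielded mini-batch, or None
    # before iteration begins.
    for batch_x, batch_y in data_iterator:
        indices = data_iterator.get_batch_indices()
        print("current mini-batch was drawn from dataset rows:", indices)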


fastNLP/core/callback.py (+36, -21)

@@ -8,35 +8,35 @@ class Callback(object):
     def __init__(self):
         super(Callback, self).__init__()

-    def before_train(self):
+    def before_train(self, *args):
         # before the main training loop
         pass

-    def before_epoch(self):
+    def before_epoch(self, *args):
         # at the beginning of each epoch
         pass

-    def before_batch(self):
+    def before_batch(self, *args):
         # at the beginning of each step/mini-batch
         pass

-    def before_loss(self):
+    def before_loss(self, *args):
         # after data_forward, and before loss computation
         pass

-    def before_backward(self):
+    def before_backward(self, *args):
         # after loss computation, and before gradient backward
         pass

-    def after_batch(self):
+    def after_batch(self, *args):
         # at the end of each step/mini-batch
         pass

-    def after_epoch(self):
+    def after_epoch(self, *args):
         # at the end of each epoch
         pass

-    def after_train(self):
+    def after_train(self, *args):
         # after training loop
         pass


@@ -48,12 +48,12 @@ def transfer(func):
     :return:
     """

-    def wrapper(manager):
+    def wrapper(manager, *arg):
         returns = []
         for callback in manager.callbacks:
             for env_name, env_value in manager.env.items():
                 setattr(callback, env_name, env_value)
-            returns.append(getattr(callback, func.__name__)())
+            returns.append(getattr(callback, func.__name__)(*arg))
         return returns

     return wrapper
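
The `transfer` decorator turns each (empty) `CallbackManager` method into a broadcast: it copies the manager's `env` entries onto every registered callback as attributes, then invokes the same-named hook on each callback, now forwarding any positional arguments. A self-contained sketch of that dispatch pattern (`TinyCallback` and `TinyManager` are hypothetical names, not fastNLP API):

    class TinyCallback:
        def before_epoch(self, cur_epoch, total_epoch):
            # env entries appear as plain attributes, e.g. self.n_epoch
            print("epoch {}/{}, n_epoch={}".format(cur_epoch, total_epoch, self.n_epoch))

    class TinyManager:
        def __init__(self, callbacks, env):
            self.callbacks = callbacks
            self.env = env

        def dispatch(self, hook_name, *args):
            returns = []
            for cb in self.callbacks:
                for name, value in self.env.items():
                    setattr(cb, name, value)  # expose trainer state on the callback
                returns.append(getattr(cb, hook_name)(*args))
            return returns

    TinyManager([TinyCallback()], env={"n_epoch": 3}).dispatch("before_epoch", 1, 3)
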
@@ -91,19 +91,27 @@ class CallbackManager(Callback):
         pass

     @transfer
-    def before_epoch(self):
+    def before_epoch(self, cur_epoch, total_epoch):
         pass

     @transfer
-    def before_batch(self):
+    def before_batch(self, batch_x, batch_y, indices):
         pass

     @transfer
-    def before_loss(self):
+    def before_loss(self, batch_y, predict_y):
         pass

     @transfer
-    def before_backward(self):
+    def before_backward(self, loss, model):
+        pass
+
+    @transfer
+    def after_backward(self, model):
+        pass
+
+    @transfer
+    def after_step(self):
         pass

     @transfer
@@ -111,18 +119,25 @@ class CallbackManager(Callback):
         pass

     @transfer
-    def after_epoch(self):
+    def after_valid(self, eval_result, metric_key, optimizer):
         pass

     @transfer
-    def after_train(self):
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        pass
+
+    @transfer
+    def after_train(self, model):
+        pass
+
+    @transfer
+    def on_exception(self, exception, model, indices):
         pass


 class DummyCallback(Callback):
-    def before_train(self):
-        print("before train!!!")
-        print(self.n_epoch)
+    def before_train(self, *arg):
+        print(arg)

     def after_epoch(self):
         print("after epoch!!!")
@@ -157,5 +172,5 @@ class EchoCallback(Callback):

 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.before_train()
-    print(manager.after_epoch())
+    manager.before_train(10, 11, 12)
+    # print(manager.after_epoch())
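
With the manager hooks now carrying explicit arguments, a user callback can subscribe to exactly the state it needs. A hedged sketch of a subclass written against the new signatures (`IndexLogger` is illustrative, not part of this commit):

    class IndexLogger(Callback):
        # Uses the new before_batch signature: the raw batch dicts plus the
        # dataset indices exposed by Batch.get_batch_indices().
        def before_batch(self, batch_x, batch_y, indices):
            print("training on dataset rows:", indices)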

fastNLP/core/trainer.py (+31, -14)

@@ -203,7 +203,7 @@ class Trainer(object):
             self._tqdm_train()
         else:
             self._print_train()
-        self.callback_manager.after_train()
+        self.callback_manager.after_train(self.model)

         if self.dev_data is not None:
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -229,24 +229,35 @@ class Trainer(object):
         self.step = 0
         data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                               as_numpy=False)
-        total_steps = data_iterator.num_batches*self.n_epochs
+        total_steps = data_iterator.num_batches * self.n_epochs
         with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             avg_loss = 0
             for epoch in range(1, self.n_epochs+1):
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
-                self.callback_manager.before_epoch()
+                # early stopping
+                self.callback_manager.before_epoch(epoch, self.n_epochs)
                 for batch_x, batch_y in data_iterator:
-                    self.callback_manager.before_batch()
+                    indices = data_iterator.get_batch_indices()
+                    # negative sampling; replace unknown; re-weight batch_y
+                    self.callback_manager.before_batch(batch_x, batch_y, indices)
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     prediction = self._data_forward(self.model, batch_x)

-                    self.callback_manager.before_loss()
+                    # edit prediction
+                    self.callback_manager.before_loss(batch_y, prediction)
                     loss = self._compute_loss(prediction, batch_y)
                     avg_loss += loss.item()

-                    self.callback_manager.before_backward()
+                    # Is loss NaN or inf? requires_grad = False
+                    self.callback_manager.before_backward(loss, self.model)
                     self._grad_backward(loss)
+                    # gradient clipping
+                    self.callback_manager.after_backward(self.model)
+
                     self._update()
+                    # lr scheduler; lr_finder; one_cycle
+                    self.callback_manager.after_step()
+
                     self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
                     for name, param in self.model.named_parameters():
                         if param.requires_grad:
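
The inline comments annotate what each hook is meant for; `after_backward`, for instance, fires after gradients exist but before the parameter update, which is where gradient clipping belongs. A minimal sketch, assuming a PyTorch model (`ClipGradCallback` is illustrative, not part of this commit):

    import torch

    class ClipGradCallback(Callback):
        # Clip gradients between the backward pass and the optimizer step,
        # so the clipping affects the coming update.
        def after_backward(self, model):
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
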
@@ -258,23 +269,27 @@ class Trainer(object):
                         avg_loss = 0
                         pbar.update(self.print_every)
                     self.step += 1
+                    # do nothing
                     self.callback_manager.after_batch()

-                    if self.validate_every > 0 and self.step % self.validate_every == 0 \
+                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
+                        (self.validate_every < 0 and self.step % self.batch_size == len(data_iterator))) \
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
                         eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
                                    self.tester._format_eval_results(eval_res)
                         pbar.write(eval_str)
-                if self.validate_every < 0 and self.dev_data:
-                    eval_res = self._do_validation(epoch=epoch, step=self.step)
-                    eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
-                               self.tester._format_eval_results(eval_res)
-                    pbar.write(eval_str)
-                if epoch!=self.n_epochs:
+
+                # if self.validate_every < 0 and self.dev_data:
+                #     eval_res = self._do_validation(epoch=epoch, step=self.step)
+                #     eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
+                #                self.tester._format_eval_results(eval_res)
+                #     pbar.write(eval_str)
+                if epoch != self.n_epochs:
                     data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
                                           as_numpy=False)
-                self.callback_manager.after_epoch()
+                # lr decay; early stopping
+                self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
             pbar.close()

     def _print_train(self):
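
`after_epoch` now receives the epoch counters and the optimizer, matching its annotated uses ("lr decay; early stopping"). A minimal sketch of learning-rate decay through this hook, assuming a torch.optim optimizer (`LrDecayCallback` is illustrative, not part of this commit):

    class LrDecayCallback(Callback):
        # Shrink the learning rate by 5% at the end of every epoch.
        def after_epoch(self, cur_epoch, n_epoch, optimizer):
            for group in optimizer.param_groups:
                group["lr"] *= 0.95
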
@@ -340,6 +355,8 @@ class Trainer(object):
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
+        # get validation results; adjust optimizer
+        self.callback_manager.after_valid(res, self.metric_key, self.optimizer)
         return res

     def _mode(self, model, is_test=False):
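
`after_valid` fires at the end of every `_do_validation` call, handing callbacks the fresh metrics, the tracked `metric_key`, and the optimizer. A hedged sketch that merely reports what it receives (`ValidReporter` is illustrative; the shape of `eval_result` is whatever the Tester returns):

    class ValidReporter(Callback):
        # Log each validation outcome; a real callback might adjust the
        # optimizer or trigger early stopping based on eval_result.
        def after_valid(self, eval_result, metric_key, optimizer):
            print("validation:", eval_result, "| tracked metric:", metric_key)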

