diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 05160312..9ba8dca8 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -26,7 +26,8 @@ class Batch(object):
         self.as_numpy = as_numpy
         self.idx_list = None
         self.curidx = 0
-        self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)
+        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
+        self.cur_batch_indices = None
 
     def __iter__(self):
         self.idx_list = self.sampler(self.dataset)
@@ -42,6 +43,7 @@ class Batch(object):
             batch_x, batch_y = {}, {}
 
             indices = self.idx_list[self.curidx:endidx]
+            self.cur_batch_indices = indices
 
             for field_name, field in self.dataset.get_all_fields().items():
                 if field.is_target or field.is_input:
@@ -60,6 +62,9 @@ class Batch(object):
     def __len__(self):
         return self.num_batches
 
+    def get_batch_indices(self):
+        return self.cur_batch_indices
+
 
 def to_tensor(batch, dtype):
     if dtype in (int, np.int8, np.int16, np.int32, np.int64):
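
For context, a minimal sketch of how the new accessor is meant to be consumed when iterating a Batch. Here `train_data` stands in for an existing fastNLP DataSet and `RandomSampler` is assumed to come from fastNLP.core.sampler; neither is part of this patch:

    from fastNLP.core.batch import Batch
    from fastNLP.core.sampler import RandomSampler

    # train_data: a fastNLP DataSet whose input/target fields are already set
    data_iterator = Batch(train_data, batch_size=32, sampler=RandomSampler(), as_numpy=False)
    for batch_x, batch_y in data_iterator:
        # dataset indices of the examples in the current mini-batch,
        # recorded by __next__ through self.cur_batch_indices
        indices = data_iterator.get_batch_indices()

The Trainer changes below use exactly this call to hand per-batch indices to the callback hooks.
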
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index b172f3a4..8b6bfdc2 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -8,35 +8,35 @@ class Callback(object):
     def __init__(self):
         super(Callback, self).__init__()
 
-    def before_train(self):
+    def before_train(self, *args):
         # before the main training loop
         pass
 
-    def before_epoch(self):
+    def before_epoch(self, *args):
         # at the beginning of each epoch
         pass
 
-    def before_batch(self):
+    def before_batch(self, *args):
         # at the beginning of each step/mini-batch
         pass
 
-    def before_loss(self):
+    def before_loss(self, *args):
         # after data_forward, and before loss computation
         pass
 
-    def before_backward(self):
+    def before_backward(self, *args):
         # after loss computation, and before gradient backward
         pass
 
-    def after_batch(self):
+    def after_batch(self, *args):
         # at the end of each step/mini-batch
         pass
 
-    def after_epoch(self):
+    def after_epoch(self, *args):
         # at the end of each epoch
         pass
 
-    def after_train(self):
+    def after_train(self, *args):
         # after training loop
         pass
 
@@ -48,12 +48,12 @@ def transfer(func):
     :return:
     """
 
-    def wrapper(manager):
+    def wrapper(manager, *arg):
         returns = []
         for callback in manager.callbacks:
             for env_name, env_value in manager.env.items():
                 setattr(callback, env_name, env_value)
-            returns.append(getattr(callback, func.__name__)())
+            returns.append(getattr(callback, func.__name__)(*arg))
         return returns
 
     return wrapper
 
@@ -91,19 +91,27 @@ class CallbackManager(Callback):
         pass
 
     @transfer
-    def before_epoch(self):
+    def before_epoch(self, cur_epoch, total_epoch):
         pass
 
     @transfer
-    def before_batch(self):
+    def before_batch(self, batch_x, batch_y, indices):
         pass
 
     @transfer
-    def before_loss(self):
+    def before_loss(self, batch_y, predict_y):
         pass
 
     @transfer
-    def before_backward(self):
+    def before_backward(self, loss, model):
+        pass
+
+    @transfer
+    def after_backward(self, model):
+        pass
+
+    @transfer
+    def after_step(self):
         pass
 
     @transfer
@@ -111,18 +119,25 @@ class CallbackManager(Callback):
         pass
 
     @transfer
-    def after_epoch(self):
+    def after_valid(self, eval_result, metric_key, optimizer):
         pass
 
     @transfer
-    def after_train(self):
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        pass
+
+    @transfer
+    def after_train(self, model):
+        pass
+
+    @transfer
+    def on_exception(self, exception, model, indices):
         pass
 
 
 class DummyCallback(Callback):
-    def before_train(self):
-        print("before train!!!")
-        print(self.n_epoch)
+    def before_train(self, *arg):
+        print(arg)
 
     def after_epoch(self):
         print("after epoch!!!")
@@ -157,5 +172,5 @@ class EchoCallback(Callback):
 
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.before_train()
-    print(manager.after_epoch())
+    manager.before_train(10, 11, 12)
+    # print(manager.after_epoch())
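
Because the base Callback hooks now accept *args and the transfer wrapper forwards whatever the Trainer passes, a user-defined callback can declare the concrete arguments listed on CallbackManager. A hedged sketch: the clipping threshold is illustrative, and after_backward is not declared on the base Callback in this patch, so overriding it here is an assumption about how the new hooks are meant to be supplied:

    import torch
    from fastNLP.core.callback import Callback

    class GradClipCallback(Callback):
        # clip gradients between loss.backward() and optimizer.step()
        def __init__(self, clip_value=5.0):
            super(GradClipCallback, self).__init__()
            self.clip_value = clip_value

        def after_backward(self, model):
            # invoked via CallbackManager.after_backward(self.model) in the Trainer loop
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_value)

        def after_epoch(self, cur_epoch, n_epoch, optimizer):
            # optimizer is assumed to be the torch optimizer held by the Trainer
            print("epoch {}/{} finished".format(cur_epoch, n_epoch))
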
".format(epoch, self.n_epochs, self.step, total_steps) + \ + # self.tester._format_eval_results(eval_res) + # pbar.write(eval_str) + if epoch != self.n_epochs: data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False) - self.callback_manager.after_epoch() + # lr decay; early stopping + self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer) pbar.close() def _print_train(self): @@ -340,6 +355,8 @@ class Trainer(object): self.best_dev_perf = res self.best_dev_epoch = epoch self.best_dev_step = step + # get validation results; adjust optimizer + self.callback_manager.after_valid(res, self.metric_key, self.optimizer) return res def _mode(self, model, is_test=False):