diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index e4d67261..47959fd2 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -180,12 +180,11 @@ class Tester(object):
                                         f"`dict`, got {type(eval_result)}")
                     metric_name = metric.get_metric_name()
                     eval_results[metric_name] = eval_result
-
+                pbar.close()
                 end_time = time.time()
                 test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
                 # pbar.write(test_str)
                 self.logger.info(test_str)
-                pbar.close()
         except _CheckError as e:
             prev_func_signature = _get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 783997a7..787ea313 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -336,7 +336,7 @@ except:
 import warnings
 
 from .batch import DataSetIter, BatchIter
-from .callback import CallbackManager, CallbackException
+from .callback import CallbackManager, CallbackException, Callback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .metrics import _prepare_metrics
@@ -422,13 +422,8 @@ class Trainer(object):
                  batch_size=32, sampler=None, drop_last=False, update_every=1,
                  num_workers=0, n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
-                 validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
+                 validate_every=-1, save_path=None, use_tqdm=True, device=None,
                  callbacks=None, check_code_level=0, **kwargs):
-        if prefetch and num_workers==0:
-            num_workers = 1
-        if prefetch:
-            warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")
-
         super(Trainer, self).__init__()
         if not isinstance(model, nn.Module):
             raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -566,6 +561,9 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp
 
+        if isinstance(callbacks, Callback):
+            callbacks = [callbacks]
+
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
 
@@ -617,8 +615,8 @@ class Trainer(object):
 
         if self.dev_data is not None and self.best_dev_perf is not None:
             self.logger.info(
-                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
-                self.tester._format_eval_results(self.best_dev_perf), )
+                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
+            self.logger.info(self.tester._format_eval_results(self.best_dev_perf))
             results['best_eval'] = self.best_dev_perf
             results['best_epoch'] = self.best_dev_epoch
             results['best_step'] = self.best_dev_step
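
Two caller-visible changes in the trainer hunks are worth noting. First, the deprecated `prefetch` argument is removed; callers should pass `num_workers` directly, as the old deprecation warning already advised. Second, `callbacks` may now be a single `Callback` instance rather than a list, since `Trainer.__init__` wraps a lone instance in a list before handing it to `CallbackManager`. A minimal sketch of the new calling convention (`EchoCallback`, `train_data`, and `model` are hypothetical placeholders, not part of this diff):

    from fastNLP import Callback, Trainer

    class EchoCallback(Callback):
        """Hypothetical callback that logs when each epoch starts."""
        def on_epoch_begin(self):
            print(f"epoch {self.epoch} started")

    # After this patch a bare Callback instance is accepted and wrapped
    # into a list internally, so these two calls are equivalent:
    trainer = Trainer(train_data=train_data, model=model, num_workers=1,
                      callbacks=EchoCallback())
    trainer = Trainer(train_data=train_data, model=model, num_workers=1,
                      callbacks=[EchoCallback()])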