From ce083de26b6c28b28c507ea25ad499586dc032f8 Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 20 Aug 2019 16:04:51 +0800 Subject: [PATCH] =?UTF-8?q?1.=E5=88=A0=E9=99=A4Trainer=E4=B8=AD=E7=9A=84pr?= =?UTF-8?q?efetch=E5=8F=82=E6=95=B0;=202.=E5=A2=9E=E5=8A=A0=E4=B8=AD?= =?UTF-8?q?=E6=96=87=E5=88=86=E8=AF=8D=E7=9A=84=E4=B8=8B=E8=BD=BD;=203.?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0DataBundle=E7=9A=84delete=5Fdataset,=20delete?= =?UTF-8?q?=5Fvocab?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 3 +-- fastNLP/core/trainer.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index e4d67261..47959fd2 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -180,12 +180,11 @@ class Tester(object): f"`dict`, got {type(eval_result)}") metric_name = metric.get_metric_name() eval_results[metric_name] = eval_result - + pbar.close() end_time = time.time() test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' # pbar.write(test_str) self.logger.info(test_str) - pbar.close() except _CheckError as e: prev_func_signature = _get_func_signature(self._predict_func) _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 783997a7..787ea313 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -336,7 +336,7 @@ except: import warnings from .batch import DataSetIter, BatchIter -from .callback import CallbackManager, CallbackException +from .callback import CallbackManager, CallbackException, Callback from .dataset import DataSet from .losses import _prepare_losser from .metrics import _prepare_metrics @@ -422,13 +422,8 @@ class Trainer(object): batch_size=32, sampler=None, drop_last=False, update_every=1, num_workers=0, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, - validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, + validate_every=-1, save_path=None, use_tqdm=True, device=None, callbacks=None, check_code_level=0, **kwargs): - if prefetch and num_workers==0: - num_workers = 1 - if prefetch: - warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") - super(Trainer, self).__init__() if not isinstance(model, nn.Module): raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") @@ -566,6 +561,9 @@ class Trainer(object): self.step = 0 self.start_time = None # start timestamp + if isinstance(callbacks, Callback): + callbacks = [callbacks] + self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) @@ -617,8 +615,8 @@ class Trainer(object): if self.dev_data is not None and self.best_dev_perf is not None: self.logger.info( - "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + - self.tester._format_eval_results(self.best_dev_perf), ) + "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step)) + self.logger.info(self.tester._format_eval_results(self.best_dev_perf)) results['best_eval'] = self.best_dev_perf results['best_epoch'] = self.best_dev_epoch results['best_step'] = self.best_dev_step