From e1fbc2cfebb14f21ec3d675408d01a59522d3dc8 Mon Sep 17 00:00:00 2001 From: yhcc Date: Sun, 3 Jul 2022 14:09:41 +0800 Subject: [PATCH] =?UTF-8?q?progres=5Fbar=E7=BB=9F=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 1 + fastNLP/core/callbacks/__init__.py | 3 ++- fastNLP/core/callbacks/progress_callback.py | 3 ++- fastNLP/core/controllers/trainer.py | 10 ++++++---- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 300a342f..28479b05 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -7,6 +7,7 @@ __all__ = [ 'ProgressCallback', 'RichCallback', 'TqdmCallback', + 'RawTextCallback', "LRSchedCallback", 'LoadBestModelCallback', "EarlyStopCallback", diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index d1f19b96..feff9f9b 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -8,6 +8,7 @@ __all__ = [ 'ProgressCallback', 'RichCallback', 'TqdmCallback', + 'RawTextCallback', "LRSchedCallback", 'LoadBestModelCallback', @@ -31,7 +32,7 @@ from .callback import Callback from .callback_event import Event, Filter from .callback_manager import CallbackManager from .checkpoint_callback import CheckpointCallback -from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback +from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback, RawTextCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback from .early_stop_callback import EarlyStopCallback diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 681ea4d3..eda0f564 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -5,7 +5,8 @@ __all__ = [ 'choose_progress_callback', 'ProgressCallback', 'RichCallback', - 'TqdmCallback' + 'TqdmCallback', + 'RawTextCallback' ] diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index ac934bd7..8a0b25ac 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -36,6 +36,7 @@ from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.utils.exceptions import EarlyStopException from fastNLP.core.dataloaders import OverfitDataLoader +from fastNLP.core.callbacks.progress_callback import ProgressCallback class Trainer(TrainerEventTrigger): @@ -554,15 +555,16 @@ class Trainer(TrainerEventTrigger): evaluate_dataloaders = self.dataloader if evaluate_dataloaders is not None: check_evaluate_every(evaluate_every) - progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 - if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 - progress_bar = progress_bar.name + progress_bar_name = None + for callback in self.callback_manager.class_callbacks: + if isinstance(callback, ProgressCallback): + progress_bar_name = callback.name self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), - progress_bar=progress_bar, + progress_bar=progress_bar_name, check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) else: raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")