Browse Source

progres_bar统一

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
e1fbc2cfeb
4 changed files with 11 additions and 6 deletions
  1. +1
    -0
      fastNLP/core/__init__.py
  2. +2
    -1
      fastNLP/core/callbacks/__init__.py
  3. +2
    -1
      fastNLP/core/callbacks/progress_callback.py
  4. +6
    -4
      fastNLP/core/controllers/trainer.py

+ 1
- 0
fastNLP/core/__init__.py View File

@@ -7,6 +7,7 @@ __all__ = [
'ProgressCallback',
'RichCallback',
'TqdmCallback',
'RawTextCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback",


+ 2
- 1
fastNLP/core/callbacks/__init__.py View File

@@ -8,6 +8,7 @@ __all__ = [
'ProgressCallback',
'RichCallback',
'TqdmCallback',
'RawTextCallback',

"LRSchedCallback",
'LoadBestModelCallback',
@@ -31,7 +32,7 @@ from .callback import Callback
from .callback_event import Event, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback, RawTextCallback
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback


+ 2
- 1
fastNLP/core/callbacks/progress_callback.py View File

@@ -5,7 +5,8 @@ __all__ = [
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
'TqdmCallback'
'TqdmCallback',
'RawTextCallback'
]




+ 6
- 4
fastNLP/core/controllers/trainer.py View File

@@ -36,6 +36,7 @@ from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException
from fastNLP.core.dataloaders import OverfitDataLoader
from fastNLP.core.callbacks.progress_callback import ProgressCallback


class Trainer(TrainerEventTrigger):
@@ -554,15 +555,16 @@ class Trainer(TrainerEventTrigger):
evaluate_dataloaders = self.dataloader
if evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
progress_bar_name = None
for callback in self.callback_manager.class_callbacks:
if isinstance(callback, ProgressCallback):
progress_bar_name = callback.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
progress_bar=progress_bar_name,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
else:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")


Loading…
Cancel
Save