|
|
@@ -36,6 +36,7 @@ from fastNLP.core.log import logger |
|
|
|
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME |
|
|
|
from fastNLP.core.utils.exceptions import EarlyStopException |
|
|
|
from fastNLP.core.dataloaders import OverfitDataLoader |
|
|
|
from fastNLP.core.callbacks.progress_callback import ProgressCallback |
|
|
|
|
|
|
|
|
|
|
|
class Trainer(TrainerEventTrigger): |
|
|
@@ -554,15 +555,16 @@ class Trainer(TrainerEventTrigger): |
|
|
|
evaluate_dataloaders = self.dataloader |
|
|
|
if evaluate_dataloaders is not None: |
|
|
|
check_evaluate_every(evaluate_every) |
|
|
|
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 |
|
|
|
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 |
|
|
|
progress_bar = progress_bar.name |
|
|
|
progress_bar_name = None |
|
|
|
for callback in self.callback_manager.class_callbacks: |
|
|
|
if isinstance(callback, ProgressCallback): |
|
|
|
progress_bar_name = callback.name |
|
|
|
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, |
|
|
|
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, |
|
|
|
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, |
|
|
|
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, |
|
|
|
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), |
|
|
|
progress_bar=progress_bar, |
|
|
|
progress_bar=progress_bar_name, |
|
|
|
check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) |
|
|
|
else: |
|
|
|
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") |
|
|
|