From 68b010d361245b66aafcfa688c3d95f4f0ebdb9d Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 9 May 2022 16:39:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20Trainer=E4=B8=AD=E7=9A=84e?= =?UTF-8?q?valuate=5Fmapping=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 9dd04a03..edaf7d65 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -207,8 +207,8 @@ class Trainer(TrainerEventTrigger): callbacks=callbacks, metrics=metrics, evaluate_every=evaluate_every, - input_mapping=evaluate_input_mapping, - output_mapping=evaluate_output_mapping, + input_mapping=train_input_mapping, + output_mapping=train_output_mapping, model_wo_auto_param_call=model_wo_auto_param_call, accumulation_steps=accumulation_steps, fp16=fp16, @@ -265,8 +265,8 @@ class Trainer(TrainerEventTrigger): progress_bar = progress_bar.name self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, - evaluate_fn=evaluate_fn, input_mapping=input_mapping, - output_mapping=output_mapping, fp16=fp16, verbose=0, + evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, + output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), progress_bar=progress_bar)