diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 59a4501b..9de400ab 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -216,8 +216,8 @@ class Trainer(TrainerEventTrigger): callbacks=callbacks, metrics=metrics, evaluate_every=evaluate_every, - input_mapping=evaluate_input_mapping, - output_mapping=evaluate_output_mapping, + input_mapping=train_input_mapping, + output_mapping=train_output_mapping, model_wo_auto_param_call=model_wo_auto_param_call, accumulation_steps=accumulation_steps, fp16=fp16, @@ -274,8 +274,8 @@ class Trainer(TrainerEventTrigger): progress_bar = progress_bar.name self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, - evaluate_fn=evaluate_fn, input_mapping=input_mapping, - output_mapping=output_mapping, fp16=fp16, verbose=0, + evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, + output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), progress_bar=progress_bar)