|
@@ -216,8 +216,8 @@ class Trainer(TrainerEventTrigger): |
|
|
callbacks=callbacks, |
|
|
callbacks=callbacks, |
|
|
metrics=metrics, |
|
|
metrics=metrics, |
|
|
evaluate_every=evaluate_every, |
|
|
evaluate_every=evaluate_every, |
|
|
input_mapping=evaluate_input_mapping, |
|
|
|
|
|
output_mapping=evaluate_output_mapping, |
|
|
|
|
|
|
|
|
input_mapping=train_input_mapping, |
|
|
|
|
|
output_mapping=train_output_mapping, |
|
|
model_wo_auto_param_call=model_wo_auto_param_call, |
|
|
model_wo_auto_param_call=model_wo_auto_param_call, |
|
|
accumulation_steps=accumulation_steps, |
|
|
accumulation_steps=accumulation_steps, |
|
|
fp16=fp16, |
|
|
fp16=fp16, |
|
@@ -274,8 +274,8 @@ class Trainer(TrainerEventTrigger): |
|
|
progress_bar = progress_bar.name |
|
|
progress_bar = progress_bar.name |
|
|
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, |
|
|
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, |
|
|
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, |
|
|
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, |
|
|
evaluate_fn=evaluate_fn, input_mapping=input_mapping, |
|
|
|
|
|
output_mapping=output_mapping, fp16=fp16, verbose=0, |
|
|
|
|
|
|
|
|
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, |
|
|
|
|
|
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, |
|
|
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), |
|
|
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), |
|
|
progress_bar=progress_bar) |
|
|
progress_bar=progress_bar) |
|
|
|
|
|
|
|
|