|
@@ -224,13 +224,14 @@ class Trainer(TrainerEventTrigger): |
|
|
# 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; |
|
|
# 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; |
|
|
# _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; |
|
|
# _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; |
|
|
self.evaluator = None |
|
|
self.evaluator = None |
|
|
self.epoch_validate = lambda *args, **kwargs: ... |
|
|
|
|
|
self.step_validate = lambda *args, **kwargs: ... |
|
|
|
|
|
self.monitor = monitor |
|
|
self.monitor = monitor |
|
|
self.larger_better = larger_better |
|
|
self.larger_better = larger_better |
|
|
if metrics is not None and validate_dataloaders is not None: |
|
|
if metrics is not None and validate_dataloaders is not None: |
|
|
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): |
|
|
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): |
|
|
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") |
|
|
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") |
|
|
|
|
|
if callable(validate_every): |
|
|
|
|
|
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " |
|
|
|
|
|
"and in this way, the kind of controlling frequency is depending on the 'step'.") |
|
|
|
|
|
|
|
|
self.evaluator = Evaluator( |
|
|
self.evaluator = Evaluator( |
|
|
model=model, |
|
|
model=model, |
|
@@ -248,16 +249,6 @@ class Trainer(TrainerEventTrigger): |
|
|
progress_bar=kwargs.get('progress_bar', 'auto') |
|
|
progress_bar=kwargs.get('progress_bar', 'auto') |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
if callable(validate_every): |
|
|
|
|
|
self._step_validate_filter = Filter(filter_fn=validate_every) |
|
|
|
|
|
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " |
|
|
|
|
|
"and in this way, the kind of controlling frequency is depending on the 'step'.") |
|
|
|
|
|
elif validate_every < 0: |
|
|
|
|
|
self._epoch_validate_filter = Filter(every=-validate_every) |
|
|
|
|
|
else: |
|
|
|
|
|
# validate_every > 0 |
|
|
|
|
|
self._step_validate_filter = Filter(every=validate_every) |
|
|
|
|
|
|
|
|
|
|
|
self.metrics = metrics |
|
|
self.metrics = metrics |
|
|
self.validate_every = validate_every |
|
|
self.validate_every = validate_every |
|
|
|
|
|
|
|
@@ -356,31 +347,38 @@ class Trainer(TrainerEventTrigger): |
|
|
raise e |
|
|
raise e |
|
|
|
|
|
|
|
|
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): |
|
|
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): |
|
|
def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None: |
|
|
|
|
|
|
|
|
def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: |
|
|
trainer.on_validate_begin() |
|
|
trainer.on_validate_begin() |
|
|
_validate_res: dict = validate_fn() |
|
|
_validate_res: dict = validate_fn() |
|
|
trainer.on_validate_end(_validate_res) |
|
|
trainer.on_validate_end(_validate_res) |
|
|
|
|
|
|
|
|
|
|
|
self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) |
|
|
|
|
|
|
|
|
|
|
|
def step_validate(self): |
|
|
if self.evaluator is not None: |
|
|
if self.evaluator is not None: |
|
|
|
|
|
should_run_validate = False |
|
|
|
|
|
|
|
|
if callable(self.validate_every): |
|
|
if callable(self.validate_every): |
|
|
self.step_validate = self._step_validate_filter(partial( |
|
|
|
|
|
_validate_fn, |
|
|
|
|
|
partial(self.evaluator.run, num_eval_batch_per_dl), |
|
|
|
|
|
self |
|
|
|
|
|
)) |
|
|
|
|
|
elif self.validate_every < 0: |
|
|
|
|
|
self.epoch_validate = self._epoch_validate_filter(partial( |
|
|
|
|
|
_validate_fn, |
|
|
|
|
|
partial(self.evaluator.run, num_eval_batch_per_dl), |
|
|
|
|
|
self |
|
|
|
|
|
)) |
|
|
|
|
|
else: |
|
|
|
|
|
# validate_every > 0 |
|
|
|
|
|
self.step_validate = self._step_validate_filter(partial( |
|
|
|
|
|
_validate_fn, |
|
|
|
|
|
partial(self.evaluator.run, num_eval_batch_per_dl), |
|
|
|
|
|
self |
|
|
|
|
|
)) |
|
|
|
|
|
|
|
|
if self.validate_every(self): |
|
|
|
|
|
should_run_validate = True |
|
|
|
|
|
elif self.validate_every > 0: |
|
|
|
|
|
if self.global_forward_batches % self.validate_every == 0: |
|
|
|
|
|
should_run_validate = True |
|
|
|
|
|
|
|
|
|
|
|
if should_run_validate: |
|
|
|
|
|
self.validate_fn() |
|
|
|
|
|
|
|
|
|
|
|
def epoch_validate(self): |
|
|
|
|
|
if self.evaluator is not None: |
|
|
|
|
|
should_run_validate = False |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(self.validate_every, int) and self.validate_every < 0: |
|
|
|
|
|
validate_every = -self.validate_every |
|
|
|
|
|
if self.cur_epoch_idx % validate_every == 0: |
|
|
|
|
|
should_run_validate = True |
|
|
|
|
|
|
|
|
|
|
|
if should_run_validate: |
|
|
|
|
|
self.validate_fn() |
|
|
|
|
|
|
|
|
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): |
|
|
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): |
|
|
r""" |
|
|
r""" |
|
|