
Reworked the logic for invoking validate in the trainer

tags/v1.0.0alpha
YWMditto committed 2 years ago, commit 2f23d80ccc
3 changed files with 73 additions and 47 deletions
  1. fastNLP/core/callbacks/callback_events.py (+2, -14)
  2. fastNLP/core/controllers/trainer.py (+29, -31)
  3. tests/core/controllers/test_trainer_w_evaluator_torch.py (+42, -2)

fastNLP/core/callbacks/callback_events.py (+2, -14)

@@ -171,20 +171,8 @@ class Filter:
             self.num_called += 1
 
             # Because the inputs of our callback functions are fixed, and we can guarantee that the first argument is always the trainer;
-            # we can therefore operate like this: take the trainer out of the callback function's inputs and feed it into our trainer, thereby implementing some complex logic;
-            # meanwhile, when we find that the first input argument of the function decorated by Filter is not a trainer, we pass only `self` into the `_filter` function;
-
-            # logic for extracting the arguments;
-            trainer = kwargs.get("trainer", None)
-
-            if trainer is None and len(args) > 0:
-                trainer = args[0]
-            if isinstance(trainer, fastNLP.Trainer):  # Because of the circular-import problem here we cannot use fastNLP.Trainer directly, since Trainer
-                # also imports this module, while Controller does not;
-                param = (self, trainer)
-            else:
-                param = (self, )
-            if self._filter(*param):
+            trainer = args[0]
+            if self._filter(self, trainer):
                 self.num_executed += 1
                 return fn(*args, **kwargs)
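
The net effect: Filter's wrapper now assumes the first positional argument of the decorated callback is always the trainer, rather than probing kwargs and type-checking against fastNLP.Trainer. A minimal, self-contained sketch of this filter pattern (illustrative only, not the actual fastNLP implementation; the lambda bodies and `wraps` usage are assumptions):

from functools import wraps

class Filter:
    def __init__(self, every=None, filter_fn=None):
        # Fire either every `every` calls, or whenever the user predicate
        # `filter_fn(trainer)` is truthy; stored as a plain function so the
        # wrapper below can call it as `self._filter(self, trainer)`.
        if filter_fn is not None:
            self._filter = lambda flt, trainer: bool(filter_fn(trainer))
        else:
            self._filter = lambda flt, trainer: flt.num_called % every == 0
        self.num_called = 0
        self.num_executed = 0

    def __call__(self, fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            self.num_called += 1
            # After this commit: the first positional argument is the trainer.
            trainer = args[0]
            if self._filter(self, trainer):
                self.num_executed += 1
                return fn(*args, **kwargs)
        return wrapper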




fastNLP/core/controllers/trainer.py (+29, -31)

@@ -224,13 +224,14 @@ class Trainer(TrainerEventTrigger):
         # To check at every point in the train loop whether validate needs to run, we fix, ahead of time at trainer initialization, the functions to be run at the corresponding moments;
         # _epoch_validate means validating once every few epochs; _step_validate means validating once every few steps;
         self.evaluator = None
-        self.epoch_validate = lambda *args, **kwargs: ...
-        self.step_validate = lambda *args, **kwargs: ...
         self.monitor = monitor
         self.larger_better = larger_better
         if metrics is not None and validate_dataloaders is not None:
             if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
                 raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
+            if callable(validate_every):
+                logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, "
+                            "and in this way, the kind of controlling frequency is depending on the 'step'.")
 
             self.evaluator = Evaluator(
                 model=model,
@@ -248,16 +249,6 @@ class Trainer(TrainerEventTrigger):
                 progress_bar=kwargs.get('progress_bar', 'auto')
             )
 
-            if callable(validate_every):
-                self._step_validate_filter = Filter(filter_fn=validate_every)
-                logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, "
-                            "and in this way, the kind of controlling frequency is depending on the 'step'.")
-            elif validate_every < 0:
-                self._epoch_validate_filter = Filter(every=-validate_every)
-            else:
-                # validate_every > 0
-                self._step_validate_filter = Filter(every=validate_every)
-
             self.metrics = metrics
             self.validate_every = validate_every

@@ -356,31 +347,38 @@ class Trainer(TrainerEventTrigger):
                 raise e
 
     def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
-        def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None:
+        def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None:
             trainer.on_validate_begin()
             _validate_res: dict = validate_fn()
             trainer.on_validate_end(_validate_res)
 
+        self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+
+    def step_validate(self):
         if self.evaluator is not None:
+            should_run_validate = False
+
             if callable(self.validate_every):
-                self.step_validate = self._step_validate_filter(partial(
-                    _validate_fn,
-                    partial(self.evaluator.run, num_eval_batch_per_dl),
-                    self
-                ))
-            elif self.validate_every < 0:
-                self.epoch_validate = self._epoch_validate_filter(partial(
-                    _validate_fn,
-                    partial(self.evaluator.run, num_eval_batch_per_dl),
-                    self
-                ))
-            else:
-                # validate_every > 0
-                self.step_validate = self._step_validate_filter(partial(
-                    _validate_fn,
-                    partial(self.evaluator.run, num_eval_batch_per_dl),
-                    self
-                ))
+                if self.validate_every(self):
+                    should_run_validate = True
+            elif self.validate_every > 0:
+                if self.global_forward_batches % self.validate_every == 0:
+                    should_run_validate = True
+
+            if should_run_validate:
+                self.validate_fn()
+
+    def epoch_validate(self):
+        if self.evaluator is not None:
+            should_run_validate = False
+
+            if isinstance(self.validate_every, int) and self.validate_every < 0:
+                validate_every = -self.validate_every
+                if self.cur_epoch_idx % validate_every == 0:
+                    should_run_validate = True
+
+            if should_run_validate:
+                self.validate_fn()
 
     def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
         r"""


tests/core/controllers/test_trainer_w_evaluator_torch.py (+42, -2)

@@ -98,14 +98,16 @@ def model_and_optimizers(request):
 
 
 # Test the ordinary case;
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])])  # , ("torch", 1), ("torch", [0, 1])
+@pytest.mark.parametrize("driver,device", [("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
 @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
+@pytest.mark.parametrize("validate_every", [-3])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
         callbacks,
+        validate_every,
         n_epochs=10,
 ):
     trainer = Trainer(
@@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator(
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
+        validate_every=validate_every,
 
         n_epochs=n_epochs,
         callbacks=callbacks,
         output_from_new_proc="all"
-
     )
 
     trainer.run()
@@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         dist.destroy_process_group()
 
 
+
+@pytest.mark.parametrize("driver,device", [("torch", 1)])  # ("torch", [0, 1]), ("torch", 1)
+@magic_argv_env_context
+def test_trainer_validate_every(
+        model_and_optimizers: TrainerParameters,
+        driver,
+        device,
+        n_epochs=6,
+):
+
+    def validate_every(trainer):
+        if trainer.global_forward_batches % 10 == 0:
+            print(trainer)
+            print("\nfastNLP test validate every.\n")
+            print(trainer.global_forward_batches)
+            return True
+
+    trainer = Trainer(
+        model=model_and_optimizers.model,
+        driver=driver,
+        device=device,
+        optimizers=model_and_optimizers.optimizers,
+        train_dataloader=model_and_optimizers.train_dataloader,
+        validate_dataloaders=model_and_optimizers.validate_dataloaders,
+        input_mapping=model_and_optimizers.input_mapping,
+        output_mapping=model_and_optimizers.output_mapping,
+        metrics=model_and_optimizers.metrics,
+
+        n_epochs=n_epochs,
+        output_from_new_proc="all",
+        validate_every=validate_every
+    )
+
+    trainer.run()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
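
For reference, the three accepted forms of validate_every after this commit (variable names below are illustrative); note that 0 is rejected in Trainer.__init__ with the ValueError shown above:

# Illustrative values for Trainer's `validate_every` parameter:
every_100_steps = 100   # positive int: validate every 100 global forward batches (step_validate)
every_3_epochs = -3     # negative int: validate every 3 epochs (epoch_validate)
def custom(trainer):    # callable: per-step predicate receiving the trainer
    return trainer.global_forward_batches % 10 == 0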




