From b9b0b5343036b47654895bebc20249a3c8882ec0 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Wed, 13 Apr 2022 19:09:27 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=B0=86=20Events=20=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=E5=B0=8F=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_events.py | 48 +++++++++---------- .../test_checkpoint_callback_torch.py | 2 +- .../test_trainer_wo_evaluator_torch.py | 2 +- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 1c805ac2..7a25c45a 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -74,30 +74,30 @@ class EventEnum(_SingleEventState, Enum): @unique class Events(EventEnum): - ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized" - ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin" - ON_SANITY_CHECK_END = "on_sanity_check_end" - ON_TRAIN_BEGIN = "on_train_begin" - ON_TRAIN_END = "on_train_end" - ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin" - ON_TRAIN_EPOCH_END = "on_train_epoch_end" - ON_FETCH_DATA_BEGIN = "on_fetch_data_begin" - ON_FETCH_DATA_END = "on_fetch_data_end" - ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin" - ON_TRAIN_BATCH_END = "on_train_batch_end" - ON_EXCEPTION = "on_exception" - ON_SAVE_MODEL = "on_save_model" - ON_LOAD_MODEL = "on_load_model" - ON_SAVE_CHECKPOINT = "on_save_checkpoint" - ON_LOAD_CHECKPOINT = "on_load_checkpoint" - ON_BEFORE_BACKWARD = "on_before_backward" - ON_AFTER_BACKWARD = "on_after_backward" - ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" - ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" - ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" - ON_AFTER_ZERO_GRAD = "on_after_zero_grad" - ON_VALIDATE_BEGIN = "on_validate_begin" - ON_VALIDATE_END = "on_validate_end" + on_after_trainer_initialized = "on_after_trainer_initialized" + on_sanity_check_begin = "on_sanity_check_begin" + on_sanity_check_end = "on_sanity_check_end" + on_train_begin = "on_train_begin" + on_train_end = "on_train_end" + on_train_epoch_begin = "on_train_epoch_begin" + on_train_epoch_end = "on_train_epoch_end" + on_fetch_data_begin = "on_fetch_data_begin" + on_fetch_data_end = "on_fetch_data_end" + on_train_batch_begin = "on_train_batch_begin" + on_train_batch_end = "on_train_batch_end" + on_exception = "on_exception" + on_save_model = "on_save_model" + on_load_model = "on_load_model" + on_save_checkpoint = "on_save_checkpoint" + on_load_checkpoint = "on_load_checkpoint" + on_before_backward = "on_before_backward" + on_after_backward = "on_after_backward" + on_before_optimizers_step = "on_before_optimizers_step" + on_after_optimizers_step = "on_after_optimizers_step" + on_before_zero_grad = "on_before_zero_grad" + on_after_zero_grad = "on_after_zero_grad" + on_validate_begin = "on_validate_begin" + on_validate_end = "on_validate_end" class EventsList: diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 557c31b2..fe0a3582 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2( from fastNLP.core.callbacks.callback_events import Events - @Trainer.on(Events.ON_TRAIN_EPOCH_END) + @Trainer.on(Events.on_train_epoch_end) def raise_exception(trainer): if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: raise NotImplementedError diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 0da8c976..82fa3af0 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -254,7 +254,7 @@ def test_trainer_on_exception( ): from fastNLP.core.callbacks.callback_events import Events - @Trainer.on(Events.ON_TRAIN_EPOCH_END) + @Trainer.on(Events.on_train_epoch_end) def raise_exception(trainer): if trainer.driver.get_local_rank() == cur_rank: raise NotImplementedError From 2f23d80ccc19645bca43d44cdefd208778065a6f Mon Sep 17 00:00:00 2001 From: YWMditto Date: Thu, 14 Apr 2022 00:45:17 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20trainer=20?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=20validate=20=E7=9A=84=E8=B0=83=E7=94=A8?= =?UTF-8?q?=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_events.py | 16 +---- fastNLP/core/controllers/trainer.py | 60 +++++++++---------- .../test_trainer_w_evaluator_torch.py | 44 +++++++++++++- 3 files changed, 73 insertions(+), 47 deletions(-) diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index 7a25c45a..ef972b35 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -171,20 +171,8 @@ class Filter: self.num_called += 1 # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; - # 因此我们就可以这样进行操作,将 trainer 从 callback 函数的输入中取出来,送到我们的 trainer 里去,从而实现一些复杂的逻辑; - # 与此同时,当我们发现 Filter 所修饰的函数的输入第一个参数不是 trainer 时,我们就只传入一个 self 到 _filter 函数中; - - # 提取参数的逻辑; - trainer = kwargs.get("trainer", None) - - if trainer is None and len(args) > 0: - trainer = args[0] - if isinstance(trainer, fastNLP.Trainer): # 这里因为重复调用的问题,我们不能直接使用 fastNLP.Trainer,因为 Trainer - # 也会调用这个 module,但是 Controller 不会; - param = (self, trainer) - else: - param = (self, ) - if self._filter(*param): + trainer = args[0] + if self._filter(self, trainer): self.num_executed += 1 return fn(*args, **kwargs) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index d8e984a1..e1f31375 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -224,13 +224,14 @@ class Trainer(TrainerEventTrigger): # 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; # _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; self.evaluator = None - self.epoch_validate = lambda *args, **kwargs: ... - self.step_validate = lambda *args, **kwargs: ... self.monitor = monitor self.larger_better = larger_better if metrics is not None and validate_dataloaders is not None: if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") + if callable(validate_every): + logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " + "and in this way, the kind of controlling frequency is depending on the 'step'.") self.evaluator = Evaluator( model=model, @@ -248,16 +249,6 @@ class Trainer(TrainerEventTrigger): progress_bar=kwargs.get('progress_bar', 'auto') ) - if callable(validate_every): - self._step_validate_filter = Filter(filter_fn=validate_every) - logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " - "and in this way, the kind of controlling frequency is depending on the 'step'.") - elif validate_every < 0: - self._epoch_validate_filter = Filter(every=-validate_every) - else: - # validate_every > 0 - self._step_validate_filter = Filter(every=validate_every) - self.metrics = metrics self.validate_every = validate_every @@ -356,31 +347,38 @@ class Trainer(TrainerEventTrigger): raise e def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): - def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None: + def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: trainer.on_validate_begin() _validate_res: dict = validate_fn() trainer.on_validate_end(_validate_res) + self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) + + def step_validate(self): if self.evaluator is not None: + should_run_validate = False + if callable(self.validate_every): - self.step_validate = self._step_validate_filter(partial( - _validate_fn, - partial(self.evaluator.run, num_eval_batch_per_dl), - self - )) - elif self.validate_every < 0: - self.epoch_validate = self._epoch_validate_filter(partial( - _validate_fn, - partial(self.evaluator.run, num_eval_batch_per_dl), - self - )) - else: - # validate_every > 0 - self.step_validate = self._step_validate_filter(partial( - _validate_fn, - partial(self.evaluator.run, num_eval_batch_per_dl), - self - )) + if self.validate_every(self): + should_run_validate = True + elif self.validate_every > 0: + if self.global_forward_batches % self.validate_every == 0: + should_run_validate = True + + if should_run_validate: + self.validate_fn() + + def epoch_validate(self): + if self.evaluator is not None: + should_run_validate = False + + if isinstance(self.validate_every, int) and self.validate_every < 0: + validate_every = -self.validate_every + if self.cur_epoch_idx % validate_every == 0: + should_run_validate = True + + if should_run_validate: + self.validate_fn() def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): r""" diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 699ee3b9..70d03f8c 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -98,14 +98,16 @@ def model_and_optimizers(request): # 测试一下普通的情况; -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1]) +@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) +@pytest.mark.parametrize("validate_every", [-3]) @magic_argv_env_context def test_trainer_torch_with_evaluator( model_and_optimizers: TrainerParameters, driver, device, callbacks, + validate_every, n_epochs=10, ): trainer = Trainer( @@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator( input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, + validate_every=validate_every, n_epochs=n_epochs, callbacks=callbacks, output_from_new_proc="all" - ) trainer.run() @@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( dist.destroy_process_group() +@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) +@magic_argv_env_context +def test_trainer_validate_every( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=6, +): + + def validate_every(trainer): + if trainer.global_forward_batches % 10 == 0: + print(trainer) + print("\nfastNLP test validate every.\n") + print(trainer.global_forward_batches) + return True + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + validate_dataloaders=model_and_optimizers.validate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + output_from_new_proc="all", + validate_every=validate_every + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + +