@@ -74,30 +74,30 @@ class EventEnum(_SingleEventState, Enum): | |||||
@unique | @unique | ||||
class Events(EventEnum): | class Events(EventEnum): | ||||
ON_AFTER_TRAINER_INITIALIZED = "on_after_trainer_initialized" | |||||
ON_SANITY_CHECK_BEGIN = "on_sanity_check_begin" | |||||
ON_SANITY_CHECK_END = "on_sanity_check_end" | |||||
ON_TRAIN_BEGIN = "on_train_begin" | |||||
ON_TRAIN_END = "on_train_end" | |||||
ON_TRAIN_EPOCH_BEGIN = "on_train_epoch_begin" | |||||
ON_TRAIN_EPOCH_END = "on_train_epoch_end" | |||||
ON_FETCH_DATA_BEGIN = "on_fetch_data_begin" | |||||
ON_FETCH_DATA_END = "on_fetch_data_end" | |||||
ON_TRAIN_BATCH_BEGIN = "on_train_batch_begin" | |||||
ON_TRAIN_BATCH_END = "on_train_batch_end" | |||||
ON_EXCEPTION = "on_exception" | |||||
ON_SAVE_MODEL = "on_save_model" | |||||
ON_LOAD_MODEL = "on_load_model" | |||||
ON_SAVE_CHECKPOINT = "on_save_checkpoint" | |||||
ON_LOAD_CHECKPOINT = "on_load_checkpoint" | |||||
ON_BEFORE_BACKWARD = "on_before_backward" | |||||
ON_AFTER_BACKWARD = "on_after_backward" | |||||
ON_BEFORE_OPTIMIZERS_STEP = "on_before_optimizers_step" | |||||
ON_AFTER_OPTIMIZERS_STEP = "on_after_optimizers_step" | |||||
ON_BEFORE_ZERO_GRAD = "on_before_zero_grad" | |||||
ON_AFTER_ZERO_GRAD = "on_after_zero_grad" | |||||
ON_VALIDATE_BEGIN = "on_validate_begin" | |||||
ON_VALIDATE_END = "on_validate_end" | |||||
on_after_trainer_initialized = "on_after_trainer_initialized" | |||||
on_sanity_check_begin = "on_sanity_check_begin" | |||||
on_sanity_check_end = "on_sanity_check_end" | |||||
on_train_begin = "on_train_begin" | |||||
on_train_end = "on_train_end" | |||||
on_train_epoch_begin = "on_train_epoch_begin" | |||||
on_train_epoch_end = "on_train_epoch_end" | |||||
on_fetch_data_begin = "on_fetch_data_begin" | |||||
on_fetch_data_end = "on_fetch_data_end" | |||||
on_train_batch_begin = "on_train_batch_begin" | |||||
on_train_batch_end = "on_train_batch_end" | |||||
on_exception = "on_exception" | |||||
on_save_model = "on_save_model" | |||||
on_load_model = "on_load_model" | |||||
on_save_checkpoint = "on_save_checkpoint" | |||||
on_load_checkpoint = "on_load_checkpoint" | |||||
on_before_backward = "on_before_backward" | |||||
on_after_backward = "on_after_backward" | |||||
on_before_optimizers_step = "on_before_optimizers_step" | |||||
on_after_optimizers_step = "on_after_optimizers_step" | |||||
on_before_zero_grad = "on_before_zero_grad" | |||||
on_after_zero_grad = "on_after_zero_grad" | |||||
on_validate_begin = "on_validate_begin" | |||||
on_validate_end = "on_validate_end" | |||||
class EventsList: | class EventsList: | ||||
@@ -171,20 +171,8 @@ class Filter: | |||||
self.num_called += 1 | self.num_called += 1 | ||||
# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; | # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; | ||||
# 因此我们就可以这样进行操作,将 trainer 从 callback 函数的输入中取出来,送到我们的 trainer 里去,从而实现一些复杂的逻辑; | |||||
# 与此同时,当我们发现 Filter 所修饰的函数的输入第一个参数不是 trainer 时,我们就只传入一个 self 到 _filter 函数中; | |||||
# 提取参数的逻辑; | |||||
trainer = kwargs.get("trainer", None) | |||||
if trainer is None and len(args) > 0: | |||||
trainer = args[0] | |||||
if isinstance(trainer, fastNLP.Trainer): # 这里因为重复调用的问题,我们不能直接使用 fastNLP.Trainer,因为 Trainer | |||||
# 也会调用这个 module,但是 Controller 不会; | |||||
param = (self, trainer) | |||||
else: | |||||
param = (self, ) | |||||
if self._filter(*param): | |||||
trainer = args[0] | |||||
if self._filter(self, trainer): | |||||
self.num_executed += 1 | self.num_executed += 1 | ||||
return fn(*args, **kwargs) | return fn(*args, **kwargs) | ||||
@@ -224,13 +224,14 @@ class Trainer(TrainerEventTrigger): | |||||
# 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; | # 为了在 train 的循环中每次都检查是否需要进行 validate,这里我们提前在 trainer 初始化的时候就将对应时间点需要运行的函数确定下来; | ||||
# _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; | # _epoch_validate 表示每隔几个 epoch validate 一次;_step_validate 表示每隔几个 step validate 一次; | ||||
self.evaluator = None | self.evaluator = None | ||||
self.epoch_validate = lambda *args, **kwargs: ... | |||||
self.step_validate = lambda *args, **kwargs: ... | |||||
self.monitor = monitor | self.monitor = monitor | ||||
self.larger_better = larger_better | self.larger_better = larger_better | ||||
if metrics is not None and validate_dataloaders is not None: | if metrics is not None and validate_dataloaders is not None: | ||||
if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): | ||||
raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") | ||||
if callable(validate_every): | |||||
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " | |||||
"and in this way, the kind of controlling frequency is depending on the 'step'.") | |||||
self.evaluator = Evaluator( | self.evaluator = Evaluator( | ||||
model=model, | model=model, | ||||
@@ -248,16 +249,6 @@ class Trainer(TrainerEventTrigger): | |||||
progress_bar=kwargs.get('progress_bar', 'auto') | progress_bar=kwargs.get('progress_bar', 'auto') | ||||
) | ) | ||||
if callable(validate_every): | |||||
self._step_validate_filter = Filter(filter_fn=validate_every) | |||||
logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, " | |||||
"and in this way, the kind of controlling frequency is depending on the 'step'.") | |||||
elif validate_every < 0: | |||||
self._epoch_validate_filter = Filter(every=-validate_every) | |||||
else: | |||||
# validate_every > 0 | |||||
self._step_validate_filter = Filter(every=validate_every) | |||||
self.metrics = metrics | self.metrics = metrics | ||||
self.validate_every = validate_every | self.validate_every = validate_every | ||||
@@ -356,31 +347,38 @@ class Trainer(TrainerEventTrigger): | |||||
raise e | raise e | ||||
def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): | ||||
def _validate_fn(validate_fn: Callable, trainer: Trainer) -> None: | |||||
def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: | |||||
trainer.on_validate_begin() | trainer.on_validate_begin() | ||||
_validate_res: dict = validate_fn() | _validate_res: dict = validate_fn() | ||||
trainer.on_validate_end(_validate_res) | trainer.on_validate_end(_validate_res) | ||||
self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) | |||||
def step_validate(self): | |||||
if self.evaluator is not None: | if self.evaluator is not None: | ||||
should_run_validate = False | |||||
if callable(self.validate_every): | if callable(self.validate_every): | ||||
self.step_validate = self._step_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
elif self.validate_every < 0: | |||||
self.epoch_validate = self._epoch_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
else: | |||||
# validate_every > 0 | |||||
self.step_validate = self._step_validate_filter(partial( | |||||
_validate_fn, | |||||
partial(self.evaluator.run, num_eval_batch_per_dl), | |||||
self | |||||
)) | |||||
if self.validate_every(self): | |||||
should_run_validate = True | |||||
elif self.validate_every > 0: | |||||
if self.global_forward_batches % self.validate_every == 0: | |||||
should_run_validate = True | |||||
if should_run_validate: | |||||
self.validate_fn() | |||||
def epoch_validate(self): | |||||
if self.evaluator is not None: | |||||
should_run_validate = False | |||||
if isinstance(self.validate_every, int) and self.validate_every < 0: | |||||
validate_every = -self.validate_every | |||||
if self.cur_epoch_idx % validate_every == 0: | |||||
should_run_validate = True | |||||
if should_run_validate: | |||||
self.validate_fn() | |||||
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | ||||
r""" | r""" | ||||
@@ -238,7 +238,7 @@ def test_model_checkpoint_callback_2( | |||||
from fastNLP.core.callbacks.callback_events import Events | from fastNLP.core.callbacks.callback_events import Events | ||||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -98,14 +98,16 @@ def model_and_optimizers(request): | |||||
# 测试一下普通的情况; | # 测试一下普通的情况; | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) #, ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | |||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | ||||
@pytest.mark.parametrize("validate_every", [-3]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_with_evaluator( | def test_trainer_torch_with_evaluator( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
callbacks, | callbacks, | ||||
validate_every, | |||||
n_epochs=10, | n_epochs=10, | ||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -118,11 +120,11 @@ def test_trainer_torch_with_evaluator( | |||||
input_mapping=model_and_optimizers.input_mapping, | input_mapping=model_and_optimizers.input_mapping, | ||||
output_mapping=model_and_optimizers.output_mapping, | output_mapping=model_and_optimizers.output_mapping, | ||||
metrics=model_and_optimizers.metrics, | metrics=model_and_optimizers.metrics, | ||||
validate_every=validate_every, | |||||
n_epochs=n_epochs, | n_epochs=n_epochs, | ||||
callbacks=callbacks, | callbacks=callbacks, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
) | ) | ||||
trainer.run() | trainer.run() | ||||
@@ -169,4 +171,42 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.parametrize("driver,device", [("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | |||||
@magic_argv_env_context | |||||
def test_trainer_validate_every( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=6, | |||||
): | |||||
def validate_every(trainer): | |||||
if trainer.global_forward_batches % 10 == 0: | |||||
print(trainer) | |||||
print("\nfastNLP test validate every.\n") | |||||
print(trainer.global_forward_batches) | |||||
return True | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
validate_dataloaders=model_and_optimizers.validate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
output_from_new_proc="all", | |||||
validate_every=validate_every | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@@ -254,7 +254,7 @@ def test_trainer_on_exception( | |||||
): | ): | ||||
from fastNLP.core.callbacks.callback_events import Events | from fastNLP.core.callbacks.callback_events import Events | ||||
@Trainer.on(Events.ON_TRAIN_EPOCH_END) | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == cur_rank: | if trainer.driver.get_local_rank() == cur_rank: | ||||
raise NotImplementedError | raise NotImplementedError | ||||