From bf2497786190ae4d87ca3d5c64623c493515c647 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 3 May 2022 17:33:10 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9Event=EF=BC=8C=E5=88=A0?= =?UTF-8?q?=E9=99=A4EventsList,Events=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 68 ++- fastNLP/core/callbacks/__init__.py | 5 +- fastNLP/core/callbacks/callback.py | 56 +- fastNLP/core/callbacks/callback_event.py | 489 ++++++++++++++++++ fastNLP/core/callbacks/callback_events.py | 206 -------- fastNLP/core/callbacks/callback_manager.py | 6 +- fastNLP/core/collators/__init__.py | 18 +- fastNLP/core/collators/padders/__init__.py | 30 ++ fastNLP/core/collators/padders/get_padder.py | 31 +- fastNLP/core/collators/padders/raw_padder.py | 5 +- .../core/collators/padders/torch_padder.py | 6 +- fastNLP/core/collators/padders/utils.py | 6 +- fastNLP/core/controllers/__init__.py | 2 - fastNLP/core/controllers/trainer.py | 66 ++- fastNLP/core/dataset/dataset.py | 9 - .../samplers/reproducible_batch_sampler.py | 1 - fastNLP/core/samplers/reproducible_sampler.py | 2 + fastNLP/core/utils/paddle_utils.py | 5 + fastNLP/io/data_bundle.py | 39 +- tests/core/callbacks/test_callback_events.py | 2 +- .../test_checkpoint_callback_torch.py | 4 +- tests/core/collators/test_collator.py | 362 ++++++++++--- tests/core/collators/test_new_collator.py | 293 ----------- .../controllers/test_trainer_event_trigger.py | 124 ++--- .../test_trainer_wo_evaluator_torch.py | 4 +- 25 files changed, 1092 insertions(+), 747 deletions(-) create mode 100644 fastNLP/core/callbacks/callback_event.py delete mode 100644 fastNLP/core/callbacks/callback_events.py delete mode 100644 tests/core/collators/test_new_collator.py diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 5cc765b9..1501851d 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -1,4 +1,53 @@ __all__ = [ + # callbacks + 'Callback', + 'Event', + 'Filter', + 'CallbackManager', + 'CheckpointCallback', + 'choose_progress_callback', + 'ProgressCallback', + 'RichCallback', + "LRSchedCallback", + 'LoadBestModelCallback', + "EarlyStopCallback", + 'MoreEvaluateCallback', + "TorchWarmupCallback", + "TorchGradClipCallback", + + # collators + 'Collator', + 'NumpyNumberPadder', + 'NumpySequencePadder', + "NumpyTensorPadder", + "Padder", + "NullPadder", + "RawNumberPadder", + "RawSequencePadder", + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder', + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder", + "get_padded_numpy_array", + + # controllers + 'Loop', + 'EvaluateBatchLoop', + 'TrainBatchLoop', + 'Evaluator', + 'Trainer', + + # dataloaders TODO 需要把 mix_dataloader 的搞定 + + # dataset + 'DataSet', + 'FieldArray', + 'Instance', + 'ApplyResultException', + + # drivers "TorchSingleDriver", "TorchDDPDriver", "PaddleSingleDriver", @@ -7,16 +56,15 @@ __all__ = [ "JittorMPIDriver", "TorchPaddleDriver", - "paddle_to", - "get_paddle_gpu_str", - "get_paddle_device_id", - "paddle_move_data_to_device", - "torch_paddle_move_data_to_device", -] -# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import; -from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.controllers.evaluator import Evaluator -from fastNLP.core.dataloaders.torch_dataloader import * + # log + "logger" +] +from .callbacks import * +from .collators import * +from .controllers import * +from .dataloaders import * +from .dataset import 
* from .drivers import * +from .log import * from .utils import * \ No newline at end of file diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index bbce73e0..cfda1763 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -1,7 +1,6 @@ __all__ = [ 'Callback', - 'Events', - 'EventsList', + 'Event', 'Filter', 'CallbackManager', 'CheckpointCallback', @@ -20,7 +19,7 @@ __all__ = [ from .callback import Callback -from .callback_events import EventsList, Events, Filter +from .callback_event import Event, Filter from .callback_manager import CallbackManager from .checkpoint_callback import CheckpointCallback from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 7f0c290d..6b5c74fc 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -3,10 +3,9 @@ __all__ = [ 'Callback', ] -from typing import Union, Callable, Dict, Optional, Any +from typing import Callable, Dict, Optional -from .callback_events import Events, EventsList, Filter -from fastNLP.core.callbacks.callback_events import _SingleEventState +from .callback_event import Event, Filter class Callback: @@ -14,32 +13,35 @@ class Callback: 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; callback 调用时机顺序大概如下 Trainer.__init__(): - on_after_trainer_initialized() + on_after_trainer_initialized(trainer, driver) Trainer.run(): if num_eval_sanity_batch>0: - on_sanity_check_begin() # 如果设置了num_eval_sanity_batch - on_sanity_check_end() + on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch + on_sanity_check_end(trainer, sanity_check_res) try: - on_train_begin() + on_train_begin(trainer) while cur_epoch_idx < n_epochs: - on_train_epoch_begin() + on_train_epoch_begin(trainer) while batch_idx_in_epoch<=num_batches_per_epoch: - on_fetch_data_begin() - on_fetch_data_end() - on_train_batch_begin() - on_before_backward() - on_after_backward() - on_before_zero_grad() # 实际调用受到 accumulation_steps 影响 - on_after_zero_grad() # 实际调用受到 accumulation_steps 影响 - on_before_optimizers_step() # 实际调用受到 accumulation_steps 影响 - on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响 - on_train_batch_end() - on_train_epoch_end() + on_fetch_data_begin(trainer) + batch = next(dataloader) + on_fetch_data_end(trainer) + on_train_batch_begin(trainer, batch, indices) + on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 + on_after_backward(trainer) + on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_train_batch_end(trainer) + on_train_epoch_end(trainer) except BaseException: - self.on_exception() + self.on_exception(trainer, exception) finally: - on_train_end() - 其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将 + on_train_end(trainer) + 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ + on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 + 的时间调用。 """ def on_after_trainer_initialized(self, trainer, driver): @@ -294,18 +296,14 @@ class _CallbackWrapper(Callback): 对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 
类为其进行定制,这一个类只保留用户的 这一个 callback 函数; """ - def __init__(self, event: Union[Events, EventsList], fn: Callable): + def __init__(self, event: Event, fn: Callable): r""" - :param event: 具体的 callback 时机,例如 'on_train_begin' 等;可以多个时机,此时 `event` 的 type 应当为 'EventsList'; + :param event: 具体的 callback 时机,例如 'on_train_begin' 等; :param fn: 用户定制的 callback 函数; """ self.fn = fn - if isinstance(event, EventsList): - for each_event in event: - _filter = Filter(each_event.every, each_event.once, each_event.filter_fn) - setattr(self, each_event.value, _filter(fn)) - elif isinstance(event, _SingleEventState): + if isinstance(event, Event): _filter = Filter(event.every, event.once, event.filter_fn) setattr(self, event.value, _filter(fn)) diff --git a/fastNLP/core/callbacks/callback_event.py b/fastNLP/core/callbacks/callback_event.py new file mode 100644 index 00000000..0199ac95 --- /dev/null +++ b/fastNLP/core/callbacks/callback_event.py @@ -0,0 +1,489 @@ +from typing import Optional, Callable, Dict +from functools import wraps + + +__all__ = [ + 'Event', + 'Filter' +] + + +def check_legality(fn): + @wraps(fn) + def wrap(every=1, once=None, filter_fn=None): + if (every is None) and (once is None) and (filter_fn is None): + raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") + + if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): + raise ValueError("These three values should be only set one.") + + if (filter_fn is not None) and not callable(filter_fn): + raise TypeError("Argument event_filter should be a callable") + + if (every is not None) and not (isinstance(every, int) and every > 0): + raise ValueError("Argument every should be integer and greater than zero") + + if (once is not None) and not (isinstance(once, int) and once > 0): + raise ValueError("Argument once should be integer and positive") + return fn(every=every, once=once, filter_fn=filter_fn) + return wrap + + +class Event: + every: Optional[int] + once: Optional[int] + + def __init__(self, value: str, every: Optional[int] = 1, once: Optional[int] = False, + filter_fn: Optional[Callable] = None): + + self.every = every + self.once = once + self.filter_fn = filter_fn + self.value = value + + def __str__(self): + return "".format(self.value, self.every, self.once, + self.filter_fn) + @staticmethod + @check_legality + def on_after_trainer_initialized(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_trainer_initialized 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_sanity_check_begin(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_sanity_check_begin 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_sanity_check_end(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_sanity_check_end 
时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_begin(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_begin 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_end(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_end 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_epoch_begin(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_epoch_begin 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_epoch_end(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_epoch_end 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_fetch_data_begin(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_fetch_data_begin 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_fetch_data_end(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_fetch_data_end 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_batch_begin(every=1, once=None, filter_fn=None): + """ + 当 
Trainer 运行到 on_train_batch_begin 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_batch_end(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_batch_end 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_exception(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_exception 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_save_model(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_save_model 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_load_model(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_load_model 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_save_checkpoint(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_save_checkpoint 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_load_checkpoint(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_load_checkpoint 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_load_checkpoint(every=1, once=None, filter_fn=None): + 
""" + 当 Trainer 运行到 on_load_checkpoint 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_before_backward(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_before_backward 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_after_backward(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_backward 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_before_optimizers_step(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_before_optimizers_step 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_after_optimizers_step(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_optimizers_step 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_before_zero_grad(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_before_zero_grad 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_after_zero_grad(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_zero_grad 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) + + 
@staticmethod + @check_legality + def on_evaluate_begin(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_evaluate_begin 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_evaluate_end(every=1, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_evaluate_end 时 + + 以下三个参数互斥,只能设置其中一个。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn) + + +class Filter: + def __init__(self, every: Optional[int] = 1, once: Optional[bool] = False, filter_fn: Optional[Callable] = None): + r""" + 通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; + + :param every: 表示一个函数隔多少次运行一次; + :param once: 表示一个函数只运行一次; + :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 + `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; + """ + # check legality + check_legality(lambda *args,**kwargs:...)(every, once, filter_fn) + # 设置变量,包括全局变量; + self.num_called = 0 + self.num_executed = 0 + + if every is not None: + self._every = every + self._filter = self.every_filter + elif once is not None: + self._once = once + self._filter = self.once_filter + else: + self._filter = filter_fn + + def __call__(self, fn: Callable): + + @wraps(fn) + def wrapper(*args, **kwargs) -> Callable: + self.num_called += 1 + + # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; + trainer = args[0] + if self._filter(self, trainer): + self.num_executed += 1 + return fn(*args, **kwargs) + + wrapper.__fastNLP_filter__ = self + return wrapper + + def every_filter(self, *args): + return self.num_called % self._every == 0 + + def once_filter(self, *args): + return self.num_called == self._once + + def state_dict(self) -> Dict: + r""" + 通过该函数来保存该 `Filter` 的状态; + """ + return {"num_called": self.num_called, "num_executed": self.num_executed} + + def load_state_dict(self, state: Dict): + r""" + 通过该函数来加载 `Filter` 的状态; + + :param state: 通过 `Filter.state_dict` 函数保存的状态元组; + """ + self.num_called = state["num_called"] + self.num_executed = state["num_executed"] + + + + + + + diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py deleted file mode 100644 index 7252398c..00000000 --- a/fastNLP/core/callbacks/callback_events.py +++ /dev/null @@ -1,206 +0,0 @@ -from enum import Enum, unique -from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict -from types import DynamicClassAttribute -from functools import wraps - - -__all__ = [ - 'Events', - 'EventsList', - 'Filter' -] - - -class _SingleEventState: - every: Optional[int] - once: Optional[int] - - def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, - filter_fn: Optional[Callable] = None, name: Optional[str] = None): - - # 具体的检测参数对错的逻辑放在具体的 Filter 里; - if every is None and once is None and filter_fn is None: - self.every = 1 - self.once = None - self.filter_fn = None - else: - self.every = every - self.once = once - self.filter_fn = filter_fn - - if not 
hasattr(self, "_value_"): - self._value_ = value - - if not hasattr(self, "_name_") and name is not None: - self._name_ = name - - # copied to be compatible to enum - @DynamicClassAttribute - def name(self) -> str: - """The name of the Enum member.""" - return self._name_ - - @DynamicClassAttribute - def value(self) -> str: - """The value of the Enum member.""" - return self._value_ - - def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): - return _SingleEventState(self.value, every, once, filter_fn, self.name) - - def __str__(self): - return "".format(self.name, self.every, self.once, - self.filter_fn) - - def __eq__(self, other) -> bool: - if isinstance(other, _SingleEventState): - return self.name == other.name - elif isinstance(other, str): - return self.name == other - else: - raise NotImplemented - - def __hash__(self): - return hash(self._name_) - - def __or__(self, other) -> "EventsList": - return EventsList() | self | other - - -class EventEnum(_SingleEventState, Enum): - pass - -@unique -class Events(EventEnum): - on_after_trainer_initialized = "on_after_trainer_initialized" - on_sanity_check_begin = "on_sanity_check_begin" - on_sanity_check_end = "on_sanity_check_end" - on_train_begin = "on_train_begin" - on_train_end = "on_train_end" - on_train_epoch_begin = "on_train_epoch_begin" - on_train_epoch_end = "on_train_epoch_end" - on_fetch_data_begin = "on_fetch_data_begin" - on_fetch_data_end = "on_fetch_data_end" - on_train_batch_begin = "on_train_batch_begin" - on_train_batch_end = "on_train_batch_end" - on_exception = "on_exception" - on_save_model = "on_save_model" - on_load_model = "on_load_model" - on_save_checkpoint = "on_save_checkpoint" - on_load_checkpoint = "on_load_checkpoint" - on_before_backward = "on_before_backward" - on_after_backward = "on_after_backward" - on_before_optimizers_step = "on_before_optimizers_step" - on_after_optimizers_step = "on_after_optimizers_step" - on_before_zero_grad = "on_before_zero_grad" - on_after_zero_grad = "on_after_zero_grad" - on_evaluate_begin = "on_evaluate_begin" - on_evaluate_end = "on_evaluate_end" - - -class EventsList: - """Collection of events stacked by operator `__or__`. 
- """ - - def __init__(self) -> None: - self._events = [] # type: List[Union[Events, _SingleEventState]] - - def _append(self, event: Union[Events, _SingleEventState]) -> None: - if not isinstance(event, (Events, _SingleEventState)): - raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}") - self._events.append(event) - - def __getitem__(self, item: int) -> Union[Events, _SingleEventState]: - return self._events[item] - - def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]: - return iter(self._events) - - def __len__(self) -> int: - return len(self._events) - - def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList": - self._append(event=other) - return self - - -class Filter: - def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): - r""" - 通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; - - :param every: 表示一个函数隔多少次运行一次; - :param once: 表示一个函数只在第多少次时运行一次; - :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 - `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; - """ - if (every is None) and (once is None) and (filter_fn is None): - raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") - - if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): - raise ValueError("These three values should be only set one.") - - if (filter_fn is not None) and not callable(filter_fn): - raise TypeError("Argument event_filter should be a callable") - - if (every is not None) and not (isinstance(every, int) and every > 0): - raise ValueError("Argument every should be integer and greater than zero") - - if (once is not None) and not (isinstance(once, int) and once > 0): - raise ValueError("Argument once should be integer and positive") - - # 设置变量,包括全局变量; - self.num_called = 0 - self.num_executed = 0 - - if every is not None: - self._every = every - self._filter = self.every_filter - elif once is not None: - self._once = once - self._filter = self.once_filter - else: - self._filter = filter_fn - - def __call__(self, fn: Callable): - - @wraps(fn) - def wrapper(*args, **kwargs) -> Callable: - self.num_called += 1 - - # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; - trainer = args[0] - if self._filter(self, trainer): - self.num_executed += 1 - return fn(*args, **kwargs) - - wrapper.__fastNLP_filter__ = self - return wrapper - - def every_filter(self, *args): - return self.num_called % self._every == 0 - - def once_filter(self, *args): - return self.num_called == self._once - - def state_dict(self) -> Dict: - r""" - 通过该函数来保存该 `Filter` 的状态; - """ - return {"num_called": self.num_called, "num_executed": self.num_executed} - - def load_state_dict(self, state: Dict): - r""" - 通过该函数来加载 `Filter` 的状态; - - :param state: 通过 `Filter.state_dict` 函数保存的状态元组; - """ - self.num_called = state["num_called"] - self.num_executed = state["num_executed"] - - - - - - - diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 2b8fff60..9ead6024 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -6,7 +6,7 @@ __all__ = [ 'CallbackManager' ] -from .callback_events import Events +from .callback_event import Event from .callback import Callback from fastNLP.core.log import logger from .progress_callback import ProgressCallback, choose_progress_callback @@ -110,7 +110,7 @@ class CallbackManager: def 
initialize_class_callbacks(self): r""" 在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 - 一个个 callback 时机,也就是 `Events` 的类别; + 一个个 callback 时机,也就是 `Event` 的类别; 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; :param callbacks: @@ -127,7 +127,7 @@ class CallbackManager: :param callback: 一个具体的 callback 实例; """ self.all_callbacks.append(callback) - for name, member in Events.__members__.items(): + for name, member in Event.__members__.items(): _fn = getattr(callback, member.value) if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): self.callback_fns[member.value].append(_fn) diff --git a/fastNLP/core/collators/__init__.py b/fastNLP/core/collators/__init__.py index 17cbb6ae..1e508689 100644 --- a/fastNLP/core/collators/__init__.py +++ b/fastNLP/core/collators/__init__.py @@ -1,4 +1,20 @@ __all__ = [ - 'Collator' + 'Collator', + + 'NumpyNumberPadder', + 'NumpySequencePadder', + "NumpyTensorPadder", + "Padder", + "NullPadder", + "RawNumberPadder", + "RawSequencePadder", + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder', + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder", + "get_padded_numpy_array", ] from .collator import Collator +from .padders import * diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py index e69de29b..09a5ca8d 100644 --- a/fastNLP/core/collators/padders/__init__.py +++ b/fastNLP/core/collators/padders/__init__.py @@ -0,0 +1,30 @@ + +__all__ = [ + 'NumpyNumberPadder', + 'NumpySequencePadder', + "NumpyTensorPadder", + + "Padder", + "NullPadder", + + "RawNumberPadder", + "RawSequencePadder", + + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder', + + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder", + + "get_padded_numpy_array", +] + + +from .numpy_padder import * +from .padder import Padder, NullPadder +from .raw_padder import * +from .torch_padder import * +from .paddle_padder import * +from .utils import get_padded_numpy_array \ No newline at end of file diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 3e136d7d..8bc05ff9 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -1,8 +1,3 @@ - -from typing import Dict - - - from typing import Sequence, Any, Union, Dict from abc import ABC @@ -93,6 +88,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'paddle': return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + else: + raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 if backend == 'raw': @@ -103,6 +100,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'paddle': return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + else: + raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") if depth == 1 and shape_len != 0: if backend == 'numpy': @@ -111,6 +110,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'paddle': return 
PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + else: + raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") if shape_len != 0 and depth>1: msg = "Does not support pad tensor under nested list. If you need this, please report." @@ -179,23 +180,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 return catalog - - - - -""" -from numbers import Number - -issubclass(type(3), Number) # True -issubclass(type(3.1), Number) # True -issubclass(type('3'), Number) # False -issubclass(type(True), Number) # True -issubclass(type(np.zeros(3)[0]), Number) # True -isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True -isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True -isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 -is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) -""" - - - diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py index fe0c0b14..96fb10f2 100644 --- a/fastNLP/core/collators/padders/raw_padder.py +++ b/fastNLP/core/collators/padders/raw_padder.py @@ -1,4 +1,7 @@ - +__all__ = [ + "RawNumberPadder", + "RawSequencePadder" +] from .padder import Padder from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index 7c0eaa33..80485ff3 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -1,4 +1,8 @@ - +__all__ = [ + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder' +] from inspect import isclass import numpy as np diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py index d2d3a8e0..b322897f 100644 --- a/fastNLP/core/collators/padders/utils.py +++ b/fastNLP/core/collators/padders/utils.py @@ -1,6 +1,10 @@ +__all__ = [ + 'get_padded_numpy_array' +] + + from typing import Sequence, List -from numbers import Number import re from inspect import isclass diff --git a/fastNLP/core/controllers/__init__.py b/fastNLP/core/controllers/__init__.py index 3e93343d..ec47f254 100644 --- a/fastNLP/core/controllers/__init__.py +++ b/fastNLP/core/controllers/__init__.py @@ -2,8 +2,6 @@ __all__ = [ 'Loop', 'EvaluateBatchLoop', 'TrainBatchLoop', - 'State', - 'TrainerState', 'Evaluator', 'Trainer', ] diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 6fed9dc1..e0cf4b0d 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -17,10 +17,10 @@ from .utils import State, TrainerState from .utils.utils import check_evaluate_every from .evaluator import Evaluator from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader -from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList +from fastNLP.core.callbacks import Callback, CallbackManager from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_manager import prepare_callbacks -from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.callbacks.callback_event import Event from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import get_fn_arg_names, 
match_and_substitute_params, nullcontext @@ -398,7 +398,7 @@ class Trainer(TrainerEventTrigger): if self.cur_epoch_idx % evaluate_every == 0: self.run_evaluate() - def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): + def add_callback_fn(self, event: Event, fn: Callable): r""" 在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; 这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; @@ -406,19 +406,69 @@ class Trainer(TrainerEventTrigger): :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; :param fn: 具体的 callback 函数; """ - if not isinstance(event, (_SingleEventState, EventsList)): - raise ValueError("parameter event should only be `Events` or `EventsList` type.") + if not isinstance(event, Event): + raise ValueError("parameter event should only be `Event` type.") _custom_callback = _CallbackWrapper(event, fn) self.callback_manager.dissect_one_callback(_custom_callback) @classmethod - def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None): + def on(cls, event: Event, marker: Optional[str] = None): r""" 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; + 支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如 + Trainer.__init__(): + on_after_trainer_initialized(trainer, driver) + Trainer.run(): + if num_eval_sanity_batch>0: + on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch + on_sanity_check_end(trainer, sanity_check_res) + try: + on_train_begin(trainer) + while cur_epoch_idx < n_epochs: + on_train_epoch_begin(trainer) + while batch_idx_in_epoch<=num_batches_per_epoch: + on_fetch_data_begin(trainer) + batch = next(dataloader) + on_fetch_data_end(trainer) + on_train_batch_begin(trainer, batch, indices) + on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 + on_after_backward(trainer) + on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_train_batch_end(trainer) + on_train_epoch_end(trainer) + except BaseException: + self.on_exception(trainer, exception) + finally: + on_train_end(trainer) + 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ + on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 + 特定的时间调用。 + + Example:: + from fastNLP import Event + @Trainer.on(Event.on_save_model()) + def do_something_1(trainer): + # do something + # 以上函数会在 Trainer 保存模型时执行。 + + @Trainer.on(Event.on_save_model(once=True)) + def do_something_2(trainer): + # do something + # 以上函数会在 Trainer 保存模型时执行,但只执行一次。 + + @Trainer.on(Event.on_train_batch_begin(every=2)) + def do_something_3(trainer, batch, indices): + # do something + # 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。 + 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; - :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; + :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含 + 特定的参数,可以通过上述说明查阅。 :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer 实例使用; @@ -426,9 +476,9 @@ class Trainer(TrainerEventTrigger): """ def wrapper(fn: Callable) -> Callable: - 
cls._custom_callbacks[marker].append((event, fn)) callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] _check_valid_parameters_number(fn, callback_fn_args) + cls._custom_callbacks[marker].append((event, fn)) return fn return wrapper diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 11a2536c..3b9f027e 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -770,15 +770,6 @@ class DataSet: df = self.to_pandas() return df.to_csv(path, encoding="utf-8") - def set_ignore(self, *field_names) -> None: - """ - 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 - - :param field_names: - :return: - """ - self.collator.set_ignore(*field_names) - @property def collator(self) -> Collator: if self._collator is None: diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 21b3f059..88fcb462 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -55,7 +55,6 @@ class ReproducibleBatchSampler: class ReproduceBatchSampler(ReproducibleBatchSampler): - # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): """ 可以使得 batch_sampler 对象状态恢复的 wrapper 。 diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 7edb607a..de43d7df 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet class ReproducibleSampler: """ + 可复现的 Sampler 对象。 + 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index e65cd735..e4c0a8a9 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]): else: return data.cuda(get_paddle_device_id(device)) + def get_paddle_gpu_str(device: Union[str, int]): """ 获得 `gpu:x` 类型的设备名 @@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]): return device.replace("cuda", "gpu") return f"gpu:{device}" + def get_paddle_device_id(device: Union[str, int]): """ 获得 gpu 的设备id @@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) + def is_in_paddle_dist(): """ 判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 """ return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) + def is_in_fnlp_paddle_dist(): """ 判断是否处于 FastNLP 拉起的分布式进程中 """ return FASTNLP_DISTRIBUTED_CHECK in os.environ + def is_in_paddle_launch_dist(): """ 判断是否处于 launch 启动的分布式进程中 diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 2796bb69..a14439ce 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -332,13 +332,44 @@ class DataBundle: show_progress_bar=show_progress_bar, progress_desc=progress_desc) return res - def set_pad_val(self, *field_names, val=0) -> None: + def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 
nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: self + """ for _, ds in self.iter_datasets(): - ds.set_pad_val(*field_names, val=val) + ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend, + pad_fn=pad_fn) + return self - def set_input(self, *field_names) -> None: + def set_ignore(self, *field_names) -> "DataBundle": + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: self + """ for _, ds in self.iter_datasets(): - ds.set_input(*field_names) + ds.collator.set_ignore(*field_names) + return self def __repr__(self) -> str: _str = '' diff --git a/tests/core/callbacks/test_callback_events.py b/tests/core/callbacks/test_callback_events.py index 8712b469..3ff3e1aa 100644 --- a/tests/core/callbacks/test_callback_events.py +++ b/tests/core/callbacks/test_callback_events.py @@ -1,7 +1,7 @@ import pytest from functools import reduce -from fastNLP.core.callbacks.callback_events import Events, Filter +from fastNLP.core.callbacks.callback_event import Event, Filter class TestFilter: diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 976b68ba..d4b49b89 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -216,9 +216,9 @@ def test_model_checkpoint_callback_2( path = Path.cwd().joinpath("test_model_checkpoint") path.mkdir(exist_ok=True, parents=True) - from fastNLP.core.callbacks.callback_events import Events + from fastNLP.core.callbacks.callback_event import Event - @Trainer.on(Events.on_train_epoch_end) + @Trainer.on(Event.on_train_epoch_end()) def raise_exception(trainer): if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: raise NotImplementedError diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 2b56624a..87762c16 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -1,81 +1,293 @@ + +import numpy as np import pytest -from fastNLP.core.collators import AutoCollator -from fastNLP.core.collators.collator import _MultiCollator -from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + +from fastNLP.core.collators.new_collator import Collator + + +def _assert_equal(d1, d2): + try: + if 'torch' in str(type(d1)): + if 'float64' in str(d2.dtype): + 
print(d2.dtype) + assert (d1 == d2).all().item() + else: + assert all(d1 == d2) + except TypeError: + assert d1 == d2 + except ValueError: + assert (d1 == d2).all() + + +def findDictDiff(d1, d2, path=""): + for k in d1: + if k in d2: + if isinstance(d1[k], dict): + findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) + else: + _assert_equal(d1[k], d2[k]) + else: + raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) + + +def findListDiff(d1, d2): + assert len(d1)==len(d2) + for _d1, _d2 in zip(d1, d2): + if isinstance(_d1, list): + findListDiff(_d1, _d2) + else: + _assert_equal(_d1, _d2) class TestCollator: - @pytest.mark.parametrize('as_numpy', [True, False]) - def test_auto_collator(self, as_numpy): - """ - 测试auto_collator的auto_pad功能 - - :param as_numpy: - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=as_numpy) - collator.set_input('x', 'y') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - results = [] - for bucket in bucket_data: - res = collator(bucket) - assert res['x'].shape == (40, 5) - assert res['y'].shape == (40,) - results.append(res) - - def test_auto_collator_v1(self): - """ - 测试auto_collator的set_pad_val和set_pad_val功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - collator.set_input('x') - collator.set_pad_val('x', val=-1) - collator.set_as_numpy(True) - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = collator(bucket) - print(res) - - def test_multicollator(self): - """ - 测试multicollator功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - multi_collator = _MultiCollator(collator) - multi_collator.set_as_numpy(as_numpy=True) - multi_collator.set_pad_val('x', val=-1) - multi_collator.set_input('x') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = multi_collator(bucket) - print(res) + @pytest.mark.torch + def test_run(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), 
np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + collator = Collator(backend='raw') + assert raw_pad_batch == collator(dict_batch) + collator = Collator(backend='raw') + raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='numpy') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), + 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), + 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), + 'b': np.array([[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='numpy') + numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), + np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), + np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(numpy_pad_lst, collator(list_batch)) + + if _NEED_IMPORT_TORCH: + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(list_batch)) + + def test_pad(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + 
+ # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 测试 set_pad + collator = Collator(backend='raw') + collator.set_pad('str', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + # 测试设置 pad 值 + collator = Collator(backend='raw') + collator.set_pad('nest_lst_int', pad_val=100) + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 设置 backend 和 type + collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + + # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + # [{'1'}, {'2'}]] + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_4', pad_val=None) + raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='raw') + collator.set_pad('_0', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_2', backend='numpy') + collator.set_pad('_4', backend='numpy', pad_val=100) + raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + # _single + collator = Collator() + collator.set_pad('_single') + findListDiff(list_batch, collator(list_batch)) + + def test_nest_ignore(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': 
{'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} + } + ] + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': {'int':[1, 1]}}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=None) + collator.set_ignore('str', 'int', 'lst_int') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [{'int':1}, {'int':1}]}} + pad_batch = collator(dict_batch) + findDictDiff(raw_pad_batch, pad_batch) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int') + collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) + pad_batch = collator(dict_batch) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [1, 1]}} + findDictDiff(raw_pad_batch, pad_batch) + + + + + + diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py deleted file mode 100644 index 87762c16..00000000 --- a/tests/core/collators/test_new_collator.py +++ /dev/null @@ -1,293 +0,0 @@ - -import numpy as np -import pytest - -from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR - -from fastNLP.core.collators.new_collator import Collator - - -def _assert_equal(d1, d2): - try: - if 'torch' in str(type(d1)): - if 'float64' in str(d2.dtype): - print(d2.dtype) - assert (d1 == d2).all().item() - else: - assert all(d1 == d2) - except TypeError: - assert d1 == d2 - except ValueError: - assert (d1 == d2).all() - - -def findDictDiff(d1, d2, path=""): - for k in d1: - if k in d2: - if isinstance(d1[k], dict): - findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) - else: - _assert_equal(d1[k], d2[k]) - else: - raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) - - -def findListDiff(d1, d2): - assert len(d1)==len(d2) - for _d1, _d2 in zip(d1, d2): - if isinstance(_d1, list): - findListDiff(_d1, _d2) - else: - _assert_equal(_d1, _d2) - - -class TestCollator: - - @pytest.mark.torch - def test_run(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 
'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - collator = Collator(backend='raw') - assert raw_pad_batch == collator(dict_batch) - collator = Collator(backend='raw') - raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='numpy') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), - 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), - 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), - 'b': np.array([[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='numpy') - numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), - np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), - np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(numpy_pad_lst, collator(list_batch)) - - if _NEED_IMPORT_TORCH: - import torch - collator = Collator(backend='torch') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), - 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), - 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - 'float': torch.FloatTensor([1.1, 2.1]), - 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), - 'numpy': torch.FloatTensor([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), - 'b': torch.LongTensor( - [[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='torch') - torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), - torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), - torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(torch_pad_lst, collator(list_batch)) - - def test_pad(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 
'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 测试 set_pad - collator = Collator(backend='raw') - collator.set_pad('str', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - # 测试设置 pad 值 - collator = Collator(backend='raw') - collator.set_pad('nest_lst_int', pad_val=100) - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 设置 backend 和 type - collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - - # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - # [{'1'}, {'2'}]] - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_4', pad_val=None) - raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='raw') - collator.set_pad('_0', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], 
[1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_2', backend='numpy') - collator.set_pad('_4', backend='numpy', pad_val=100) - raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - # _single - collator = Collator() - collator.set_pad('_single') - findListDiff(list_batch, collator(list_batch)) - - def test_nest_ignore(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} - } - ] - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': {'int':[1, 1]}}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=None) - collator.set_ignore('str', 'int', 'lst_int') - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [{'int':1}, {'int':1}]}} - pad_batch = collator(dict_batch) - findDictDiff(raw_pad_batch, pad_batch) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int') - collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) - pad_batch = collator(dict_batch) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [1, 1]}} - findDictDiff(raw_pad_batch, pad_batch) - - - - - - diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index fab07b3c..4ed79fb8 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -7,7 +7,7 @@ from torchmetrics import Accuracy import torch.distributed as dist from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.callbacks.callback_events import Events +from 
fastNLP.core.callbacks.callback_event import Event from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback @@ -97,11 +97,12 @@ def test_trainer_event_trigger_1( if dist.is_initialized(): dist.destroy_process_group() - for name, member in Events.__members__.items(): - assert member.value in output[0] + Event_attrs = Event.__dict__ + for k, v in Event_attrs.items(): + if isinstance(v, staticmethod): + assert k in output[0] - -@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7]) +@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) @pytest.mark.torch @magic_argv_env_context def test_trainer_event_trigger_2( @@ -111,86 +112,86 @@ def test_trainer_event_trigger_2( n_epochs=2, ): - @Trainer.on(Events.on_after_trainer_initialized()) + @Trainer.on(Event.on_after_trainer_initialized()) def on_after_trainer_initialized(trainer, driver): print("on_after_trainer_initialized") - @Trainer.on(Events.on_sanity_check_begin()) + @Trainer.on(Event.on_sanity_check_begin()) def on_sanity_check_begin(trainer): print("on_sanity_check_begin") - @Trainer.on(Events.on_sanity_check_end()) + @Trainer.on(Event.on_sanity_check_end()) def on_sanity_check_end(trainer, sanity_check_res): print("on_sanity_check_end") - @Trainer.on(Events.on_train_begin()) + @Trainer.on(Event.on_train_begin()) def on_train_begin(trainer): print("on_train_begin") - @Trainer.on(Events.on_train_end()) + @Trainer.on(Event.on_train_end()) def on_train_end(trainer): print("on_train_end") - @Trainer.on(Events.on_train_epoch_begin()) + @Trainer.on(Event.on_train_epoch_begin()) def on_train_epoch_begin(trainer): if trainer.cur_epoch_idx >= 1: # 触发 on_exception; raise Exception print("on_train_epoch_begin") - @Trainer.on(Events.on_train_epoch_end()) + @Trainer.on(Event.on_train_epoch_end()) def on_train_epoch_end(trainer): print("on_train_epoch_end") - @Trainer.on(Events.on_fetch_data_begin()) + @Trainer.on(Event.on_fetch_data_begin()) def on_fetch_data_begin(trainer): print("on_fetch_data_begin") - @Trainer.on(Events.on_fetch_data_end()) + @Trainer.on(Event.on_fetch_data_end()) def on_fetch_data_end(trainer): print("on_fetch_data_end") - @Trainer.on(Events.on_train_batch_begin()) + @Trainer.on(Event.on_train_batch_begin()) def on_train_batch_begin(trainer, batch, indices=None): print("on_train_batch_begin") - @Trainer.on(Events.on_train_batch_end()) + @Trainer.on(Event.on_train_batch_end()) def on_train_batch_end(trainer): print("on_train_batch_end") - @Trainer.on(Events.on_exception()) + @Trainer.on(Event.on_exception()) def on_exception(trainer, exception): print("on_exception") - @Trainer.on(Events.on_before_backward()) + @Trainer.on(Event.on_before_backward()) def on_before_backward(trainer, outputs): print("on_before_backward") - @Trainer.on(Events.on_after_backward()) + @Trainer.on(Event.on_after_backward()) def on_after_backward(trainer): print("on_after_backward") - @Trainer.on(Events.on_before_optimizers_step()) + @Trainer.on(Event.on_before_optimizers_step()) def on_before_optimizers_step(trainer, optimizers): print("on_before_optimizers_step") - @Trainer.on(Events.on_after_optimizers_step()) + @Trainer.on(Event.on_after_optimizers_step()) def on_after_optimizers_step(trainer, optimizers): print("on_after_optimizers_step") - 
@Trainer.on(Events.on_before_zero_grad()) + @Trainer.on(Event.on_before_zero_grad()) def on_before_zero_grad(trainer, optimizers): print("on_before_zero_grad") - @Trainer.on(Events.on_after_zero_grad()) + @Trainer.on(Event.on_after_zero_grad()) def on_after_zero_grad(trainer, optimizers): print("on_after_zero_grad") - @Trainer.on(Events.on_evaluate_begin()) + @Trainer.on(Event.on_evaluate_begin()) def on_evaluate_begin(trainer): print("on_evaluate_begin") - @Trainer.on(Events.on_evaluate_end()) + @Trainer.on(Event.on_evaluate_end()) def on_evaluate_end(trainer, results): print("on_evaluate_end") @@ -211,15 +212,10 @@ def test_trainer_event_trigger_2( ) trainer.run() - - if dist.is_initialized(): - dist.destroy_process_group() - - for name, member in Events.__members__.items(): - assert member.value in output[0] - - - + Event_attrs = Event.__dict__ + for k, v in Event_attrs.items(): + if isinstance(v, staticmethod): + assert k in output[0] @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) @pytest.mark.torch @@ -231,86 +227,86 @@ def test_trainer_event_trigger_3( n_epochs=2, ): - @Trainer.on(Events.on_after_trainer_initialized) + @Trainer.on(Event.on_after_trainer_initialized()) def on_after_trainer_initialized(trainer, driver): print("on_after_trainer_initialized") - @Trainer.on(Events.on_sanity_check_begin) + @Trainer.on(Event.on_sanity_check_begin()) def on_sanity_check_begin(trainer): print("on_sanity_check_begin") - @Trainer.on(Events.on_sanity_check_end) + @Trainer.on(Event.on_sanity_check_end()) def on_sanity_check_end(trainer, sanity_check_res): print("on_sanity_check_end") - @Trainer.on(Events.on_train_begin) + @Trainer.on(Event.on_train_begin()) def on_train_begin(trainer): print("on_train_begin") - @Trainer.on(Events.on_train_end) + @Trainer.on(Event.on_train_end()) def on_train_end(trainer): print("on_train_end") - @Trainer.on(Events.on_train_epoch_begin) + @Trainer.on(Event.on_train_epoch_begin()) def on_train_epoch_begin(trainer): if trainer.cur_epoch_idx >= 1: # 触发 on_exception; raise Exception print("on_train_epoch_begin") - @Trainer.on(Events.on_train_epoch_end) + @Trainer.on(Event.on_train_epoch_end()) def on_train_epoch_end(trainer): print("on_train_epoch_end") - @Trainer.on(Events.on_fetch_data_begin) + @Trainer.on(Event.on_fetch_data_begin()) def on_fetch_data_begin(trainer): print("on_fetch_data_begin") - @Trainer.on(Events.on_fetch_data_end) + @Trainer.on(Event.on_fetch_data_end()) def on_fetch_data_end(trainer): print("on_fetch_data_end") - @Trainer.on(Events.on_train_batch_begin) + @Trainer.on(Event.on_train_batch_begin()) def on_train_batch_begin(trainer, batch, indices=None): print("on_train_batch_begin") - @Trainer.on(Events.on_train_batch_end) + @Trainer.on(Event.on_train_batch_end()) def on_train_batch_end(trainer): print("on_train_batch_end") - @Trainer.on(Events.on_exception) + @Trainer.on(Event.on_exception()) def on_exception(trainer, exception): print("on_exception") - @Trainer.on(Events.on_before_backward) + @Trainer.on(Event.on_before_backward()) def on_before_backward(trainer, outputs): print("on_before_backward") - @Trainer.on(Events.on_after_backward) + @Trainer.on(Event.on_after_backward()) def on_after_backward(trainer): print("on_after_backward") - @Trainer.on(Events.on_before_optimizers_step) + @Trainer.on(Event.on_before_optimizers_step()) def on_before_optimizers_step(trainer, optimizers): print("on_before_optimizers_step") - @Trainer.on(Events.on_after_optimizers_step) + 
@Trainer.on(Event.on_after_optimizers_step()) def on_after_optimizers_step(trainer, optimizers): print("on_after_optimizers_step") - @Trainer.on(Events.on_before_zero_grad) + @Trainer.on(Event.on_before_zero_grad()) def on_before_zero_grad(trainer, optimizers): print("on_before_zero_grad") - @Trainer.on(Events.on_after_zero_grad) + @Trainer.on(Event.on_after_zero_grad()) def on_after_zero_grad(trainer, optimizers): print("on_after_zero_grad") - @Trainer.on(Events.on_evaluate_begin) + @Trainer.on(Event.on_evaluate_begin()) def on_evaluate_begin(trainer): print("on_evaluate_begin") - @Trainer.on(Events.on_evaluate_end) + @Trainer.on(Event.on_evaluate_end()) def on_evaluate_end(trainer, results): print("on_evaluate_end") @@ -332,19 +328,7 @@ def test_trainer_event_trigger_3( trainer.run() - if dist.is_initialized(): - dist.destroy_process_group() - - for name, member in Events.__members__.items(): - assert member.value in output[0] - - - - - - - - - - - + Event_attrs = Event.__dict__ + for k, v in Event_attrs.items(): + if isinstance(v, staticmethod): + assert k in output[0] diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 825bd425..624f80fb 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -257,9 +257,9 @@ def test_trainer_on_exception( cur_rank, n_epochs=2, ): - from fastNLP.core.callbacks.callback_events import Events + from fastNLP.core.callbacks.callback_event import Event - @Trainer.on(Events.on_train_epoch_end) + @Trainer.on(Event.on_train_epoch_end()) def raise_exception(trainer): if trainer.driver.get_local_rank() == cur_rank: raise NotImplementedError