@@ -1,4 +1,53 @@ | |||||
__all__ = [ | __all__ = [ | ||||
# callbacks | |||||
'Callback', | |||||
'Event', | |||||
'Filter', | |||||
'CallbackManager', | |||||
'CheckpointCallback', | |||||
'choose_progress_callback', | |||||
'ProgressCallback', | |||||
'RichCallback', | |||||
"LRSchedCallback", | |||||
'LoadBestModelCallback', | |||||
"EarlyStopCallback", | |||||
'MoreEvaluateCallback', | |||||
"TorchWarmupCallback", | |||||
"TorchGradClipCallback", | |||||
# collators | |||||
'Collator', | |||||
'NumpyNumberPadder', | |||||
'NumpySequencePadder', | |||||
"NumpyTensorPadder", | |||||
"Padder", | |||||
"NullPadder", | |||||
"RawNumberPadder", | |||||
"RawSequencePadder", | |||||
'TorchNumberPadder', | |||||
'TorchSequencePadder', | |||||
'TorchTensorPadder', | |||||
"PaddleNumberPadder", | |||||
"PaddleTensorPadder", | |||||
"PaddleSequencePadder", | |||||
"get_padded_numpy_array", | |||||
# controllers | |||||
'Loop', | |||||
'EvaluateBatchLoop', | |||||
'TrainBatchLoop', | |||||
'Evaluator', | |||||
'Trainer', | |||||
# dataloaders TODO 需要把 mix_dataloader 的搞定 | |||||
# dataset | |||||
'DataSet', | |||||
'FieldArray', | |||||
'Instance', | |||||
'ApplyResultException', | |||||
# drivers | |||||
"TorchSingleDriver", | "TorchSingleDriver", | ||||
"TorchDDPDriver", | "TorchDDPDriver", | ||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
@@ -7,16 +56,15 @@ __all__ = [ | |||||
"JittorMPIDriver", | "JittorMPIDriver", | ||||
"TorchPaddleDriver", | "TorchPaddleDriver", | ||||
"paddle_to", | |||||
"get_paddle_gpu_str", | |||||
"get_paddle_device_id", | |||||
"paddle_move_data_to_device", | |||||
"torch_paddle_move_data_to_device", | |||||
] | |||||
# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import; | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.controllers.evaluator import Evaluator | |||||
from fastNLP.core.dataloaders.torch_dataloader import * | |||||
# log | |||||
"logger" | |||||
] | |||||
from .callbacks import * | |||||
from .collators import * | |||||
from .controllers import * | |||||
from .dataloaders import * | |||||
from .dataset import * | |||||
from .drivers import * | from .drivers import * | ||||
from .log import * | |||||
from .utils import * | from .utils import * |
@@ -1,7 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'Callback', | 'Callback', | ||||
'Events', | |||||
'EventsList', | |||||
'Event', | |||||
'Filter', | 'Filter', | ||||
'CallbackManager', | 'CallbackManager', | ||||
'CheckpointCallback', | 'CheckpointCallback', | ||||
@@ -20,7 +19,7 @@ __all__ = [ | |||||
from .callback import Callback | from .callback import Callback | ||||
from .callback_events import EventsList, Events, Filter | |||||
from .callback_event import Event, Filter | |||||
from .callback_manager import CallbackManager | from .callback_manager import CallbackManager | ||||
from .checkpoint_callback import CheckpointCallback | from .checkpoint_callback import CheckpointCallback | ||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
@@ -3,10 +3,9 @@ __all__ = [ | |||||
'Callback', | 'Callback', | ||||
] | ] | ||||
from typing import Union, Callable, Dict, Optional, Any | |||||
from typing import Callable, Dict, Optional | |||||
from .callback_events import Events, EventsList, Filter | |||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||||
from .callback_event import Event, Filter | |||||
class Callback: | class Callback: | ||||
@@ -14,32 +13,35 @@ class Callback: | |||||
实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; | ||||
callback 调用时机顺序大概如下 | callback 调用时机顺序大概如下 | ||||
Trainer.__init__(): | Trainer.__init__(): | ||||
on_after_trainer_initialized() | |||||
on_after_trainer_initialized(trainer, driver) | |||||
Trainer.run(): | Trainer.run(): | ||||
if num_eval_sanity_batch>0: | if num_eval_sanity_batch>0: | ||||
on_sanity_check_begin() # 如果设置了num_eval_sanity_batch | |||||
on_sanity_check_end() | |||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||||
on_sanity_check_end(trainer, sanity_check_res) | |||||
try: | try: | ||||
on_train_begin() | |||||
on_train_begin(trainer) | |||||
while cur_epoch_idx < n_epochs: | while cur_epoch_idx < n_epochs: | ||||
on_train_epoch_begin() | |||||
on_train_epoch_begin(trainer) | |||||
while batch_idx_in_epoch<=num_batches_per_epoch: | while batch_idx_in_epoch<=num_batches_per_epoch: | ||||
on_fetch_data_begin() | |||||
on_fetch_data_end() | |||||
on_train_batch_begin() | |||||
on_before_backward() | |||||
on_after_backward() | |||||
on_before_zero_grad() # 实际调用受到 accumulation_steps 影响 | |||||
on_after_zero_grad() # 实际调用受到 accumulation_steps 影响 | |||||
on_before_optimizers_step() # 实际调用受到 accumulation_steps 影响 | |||||
on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响 | |||||
on_train_batch_end() | |||||
on_train_epoch_end() | |||||
on_fetch_data_begin(trainer) | |||||
batch = next(dataloader) | |||||
on_fetch_data_end(trainer) | |||||
on_train_batch_begin(trainer, batch, indices) | |||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 | |||||
on_after_backward(trainer) | |||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_train_batch_end(trainer) | |||||
on_train_epoch_end(trainer) | |||||
except BaseException: | except BaseException: | ||||
self.on_exception() | |||||
self.on_exception(trainer, exception) | |||||
finally: | finally: | ||||
on_train_end() | |||||
其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将 | |||||
on_train_end(trainer) | |||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ | |||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 | |||||
的时间调用。 | |||||
""" | """ | ||||
def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
@@ -294,18 +296,14 @@ class _CallbackWrapper(Callback): | |||||
对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的 | 对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的 | ||||
这一个 callback 函数; | 这一个 callback 函数; | ||||
""" | """ | ||||
def __init__(self, event: Union[Events, EventsList], fn: Callable): | |||||
def __init__(self, event: Event, fn: Callable): | |||||
r""" | r""" | ||||
:param event: 具体的 callback 时机,例如 'on_train_begin' 等;可以多个时机,此时 `event` 的 type 应当为 'EventsList'; | |||||
:param event: 具体的 callback 时机,例如 'on_train_begin' 等; | |||||
:param fn: 用户定制的 callback 函数; | :param fn: 用户定制的 callback 函数; | ||||
""" | """ | ||||
self.fn = fn | self.fn = fn | ||||
if isinstance(event, EventsList): | |||||
for each_event in event: | |||||
_filter = Filter(each_event.every, each_event.once, each_event.filter_fn) | |||||
setattr(self, each_event.value, _filter(fn)) | |||||
elif isinstance(event, _SingleEventState): | |||||
if isinstance(event, Event): | |||||
_filter = Filter(event.every, event.once, event.filter_fn) | _filter = Filter(event.every, event.once, event.filter_fn) | ||||
setattr(self, event.value, _filter(fn)) | setattr(self, event.value, _filter(fn)) | ||||
@@ -0,0 +1,489 @@ | |||||
from typing import Optional, Callable, Dict | |||||
from functools import wraps | |||||
__all__ = [ | |||||
'Event', | |||||
'Filter' | |||||
] | |||||
def check_legality(fn): | |||||
@wraps(fn) | |||||
def wrap(every=1, once=None, filter_fn=None): | |||||
if (every is None) and (once is None) and (filter_fn is None): | |||||
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") | |||||
if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): | |||||
raise ValueError("These three values should be only set one.") | |||||
if (filter_fn is not None) and not callable(filter_fn): | |||||
raise TypeError("Argument event_filter should be a callable") | |||||
if (every is not None) and not (isinstance(every, int) and every > 0): | |||||
raise ValueError("Argument every should be integer and greater than zero") | |||||
if (once is not None) and not (isinstance(once, int) and once > 0): | |||||
raise ValueError("Argument once should be integer and positive") | |||||
return fn(every=every, once=once, filter_fn=filter_fn) | |||||
return wrap | |||||
class Event: | |||||
every: Optional[int] | |||||
once: Optional[int] | |||||
def __init__(self, value: str, every: Optional[int] = 1, once: Optional[int] = False, | |||||
filter_fn: Optional[Callable] = None): | |||||
self.every = every | |||||
self.once = once | |||||
self.filter_fn = filter_fn | |||||
self.value = value | |||||
def __str__(self): | |||||
return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once, | |||||
self.filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_after_trainer_initialized(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_after_trainer_initialized 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_sanity_check_begin(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_sanity_check_begin 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_sanity_check_end(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_sanity_check_end 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_train_begin(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_train_begin 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_train_end(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_train_end 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_train_epoch_begin(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_train_epoch_begin 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_train_epoch_end(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_train_epoch_end 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_fetch_data_begin(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_fetch_data_begin 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_fetch_data_end(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_fetch_data_end 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_train_batch_begin(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_train_batch_begin 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_train_batch_end(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_train_batch_end 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_exception(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_exception 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_save_model(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_save_model 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_load_model(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_load_model 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_save_checkpoint(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_save_checkpoint 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_load_checkpoint(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_load_checkpoint 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_load_checkpoint(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_load_checkpoint 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_before_backward(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_before_backward 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_after_backward(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_after_backward 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_before_optimizers_step(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_before_optimizers_step 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_after_optimizers_step(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_after_optimizers_step 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_before_zero_grad(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_before_zero_grad 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_after_zero_grad(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_after_zero_grad 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_evaluate_begin(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_evaluate_begin 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) | |||||
@staticmethod | |||||
@check_legality | |||||
def on_evaluate_end(every=1, once=None, filter_fn=None): | |||||
""" | |||||
当 Trainer 运行到 on_evaluate_end 时 | |||||
以下三个参数互斥,只能设置其中一个。 | |||||
:param int every: 触发了多少次,才真正运行一次。 | |||||
:param bool once: 是否只在第一次运行后就不再执行了。 | |||||
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 | |||||
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 | |||||
:return: | |||||
""" | |||||
return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn) | |||||
class Filter: | |||||
def __init__(self, every: Optional[int] = 1, once: Optional[bool] = False, filter_fn: Optional[Callable] = None): | |||||
r""" | |||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; | |||||
:param every: 表示一个函数隔多少次运行一次; | |||||
:param once: 表示一个函数只运行一次; | |||||
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 | |||||
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; | |||||
""" | |||||
# check legality | |||||
check_legality(lambda *args,**kwargs:...)(every, once, filter_fn) | |||||
# 设置变量,包括全局变量; | |||||
self.num_called = 0 | |||||
self.num_executed = 0 | |||||
if every is not None: | |||||
self._every = every | |||||
self._filter = self.every_filter | |||||
elif once is not None: | |||||
self._once = once | |||||
self._filter = self.once_filter | |||||
else: | |||||
self._filter = filter_fn | |||||
def __call__(self, fn: Callable): | |||||
@wraps(fn) | |||||
def wrapper(*args, **kwargs) -> Callable: | |||||
self.num_called += 1 | |||||
# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; | |||||
trainer = args[0] | |||||
if self._filter(self, trainer): | |||||
self.num_executed += 1 | |||||
return fn(*args, **kwargs) | |||||
wrapper.__fastNLP_filter__ = self | |||||
return wrapper | |||||
def every_filter(self, *args): | |||||
return self.num_called % self._every == 0 | |||||
def once_filter(self, *args): | |||||
return self.num_called == self._once | |||||
def state_dict(self) -> Dict: | |||||
r""" | |||||
通过该函数来保存该 `Filter` 的状态; | |||||
""" | |||||
return {"num_called": self.num_called, "num_executed": self.num_executed} | |||||
def load_state_dict(self, state: Dict): | |||||
r""" | |||||
通过该函数来加载 `Filter` 的状态; | |||||
:param state: 通过 `Filter.state_dict` 函数保存的状态元组; | |||||
""" | |||||
self.num_called = state["num_called"] | |||||
self.num_executed = state["num_executed"] | |||||
@@ -1,206 +0,0 @@ | |||||
from enum import Enum, unique | |||||
from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict | |||||
from types import DynamicClassAttribute | |||||
from functools import wraps | |||||
__all__ = [ | |||||
'Events', | |||||
'EventsList', | |||||
'Filter' | |||||
] | |||||
class _SingleEventState: | |||||
every: Optional[int] | |||||
once: Optional[int] | |||||
def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, | |||||
filter_fn: Optional[Callable] = None, name: Optional[str] = None): | |||||
# 具体的检测参数对错的逻辑放在具体的 Filter 里; | |||||
if every is None and once is None and filter_fn is None: | |||||
self.every = 1 | |||||
self.once = None | |||||
self.filter_fn = None | |||||
else: | |||||
self.every = every | |||||
self.once = once | |||||
self.filter_fn = filter_fn | |||||
if not hasattr(self, "_value_"): | |||||
self._value_ = value | |||||
if not hasattr(self, "_name_") and name is not None: | |||||
self._name_ = name | |||||
# copied to be compatible to enum | |||||
@DynamicClassAttribute | |||||
def name(self) -> str: | |||||
"""The name of the Enum member.""" | |||||
return self._name_ | |||||
@DynamicClassAttribute | |||||
def value(self) -> str: | |||||
"""The value of the Enum member.""" | |||||
return self._value_ | |||||
def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): | |||||
return _SingleEventState(self.value, every, once, filter_fn, self.name) | |||||
def __str__(self): | |||||
return "<event={0}, every={1}, once={2}, filter fn is None:{3}>".format(self.name, self.every, self.once, | |||||
self.filter_fn) | |||||
def __eq__(self, other) -> bool: | |||||
if isinstance(other, _SingleEventState): | |||||
return self.name == other.name | |||||
elif isinstance(other, str): | |||||
return self.name == other | |||||
else: | |||||
raise NotImplemented | |||||
def __hash__(self): | |||||
return hash(self._name_) | |||||
def __or__(self, other) -> "EventsList": | |||||
return EventsList() | self | other | |||||
class EventEnum(_SingleEventState, Enum): | |||||
pass | |||||
@unique | |||||
class Events(EventEnum): | |||||
on_after_trainer_initialized = "on_after_trainer_initialized" | |||||
on_sanity_check_begin = "on_sanity_check_begin" | |||||
on_sanity_check_end = "on_sanity_check_end" | |||||
on_train_begin = "on_train_begin" | |||||
on_train_end = "on_train_end" | |||||
on_train_epoch_begin = "on_train_epoch_begin" | |||||
on_train_epoch_end = "on_train_epoch_end" | |||||
on_fetch_data_begin = "on_fetch_data_begin" | |||||
on_fetch_data_end = "on_fetch_data_end" | |||||
on_train_batch_begin = "on_train_batch_begin" | |||||
on_train_batch_end = "on_train_batch_end" | |||||
on_exception = "on_exception" | |||||
on_save_model = "on_save_model" | |||||
on_load_model = "on_load_model" | |||||
on_save_checkpoint = "on_save_checkpoint" | |||||
on_load_checkpoint = "on_load_checkpoint" | |||||
on_before_backward = "on_before_backward" | |||||
on_after_backward = "on_after_backward" | |||||
on_before_optimizers_step = "on_before_optimizers_step" | |||||
on_after_optimizers_step = "on_after_optimizers_step" | |||||
on_before_zero_grad = "on_before_zero_grad" | |||||
on_after_zero_grad = "on_after_zero_grad" | |||||
on_evaluate_begin = "on_evaluate_begin" | |||||
on_evaluate_end = "on_evaluate_end" | |||||
class EventsList: | |||||
"""Collection of events stacked by operator `__or__`. | |||||
""" | |||||
def __init__(self) -> None: | |||||
self._events = [] # type: List[Union[Events, _SingleEventState]] | |||||
def _append(self, event: Union[Events, _SingleEventState]) -> None: | |||||
if not isinstance(event, (Events, _SingleEventState)): | |||||
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}") | |||||
self._events.append(event) | |||||
def __getitem__(self, item: int) -> Union[Events, _SingleEventState]: | |||||
return self._events[item] | |||||
def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]: | |||||
return iter(self._events) | |||||
def __len__(self) -> int: | |||||
return len(self._events) | |||||
def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList": | |||||
self._append(event=other) | |||||
return self | |||||
class Filter: | |||||
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): | |||||
r""" | |||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; | |||||
:param every: 表示一个函数隔多少次运行一次; | |||||
:param once: 表示一个函数只在第多少次时运行一次; | |||||
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 | |||||
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; | |||||
""" | |||||
if (every is None) and (once is None) and (filter_fn is None): | |||||
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") | |||||
if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): | |||||
raise ValueError("These three values should be only set one.") | |||||
if (filter_fn is not None) and not callable(filter_fn): | |||||
raise TypeError("Argument event_filter should be a callable") | |||||
if (every is not None) and not (isinstance(every, int) and every > 0): | |||||
raise ValueError("Argument every should be integer and greater than zero") | |||||
if (once is not None) and not (isinstance(once, int) and once > 0): | |||||
raise ValueError("Argument once should be integer and positive") | |||||
# 设置变量,包括全局变量; | |||||
self.num_called = 0 | |||||
self.num_executed = 0 | |||||
if every is not None: | |||||
self._every = every | |||||
self._filter = self.every_filter | |||||
elif once is not None: | |||||
self._once = once | |||||
self._filter = self.once_filter | |||||
else: | |||||
self._filter = filter_fn | |||||
def __call__(self, fn: Callable): | |||||
@wraps(fn) | |||||
def wrapper(*args, **kwargs) -> Callable: | |||||
self.num_called += 1 | |||||
# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; | |||||
trainer = args[0] | |||||
if self._filter(self, trainer): | |||||
self.num_executed += 1 | |||||
return fn(*args, **kwargs) | |||||
wrapper.__fastNLP_filter__ = self | |||||
return wrapper | |||||
def every_filter(self, *args): | |||||
return self.num_called % self._every == 0 | |||||
def once_filter(self, *args): | |||||
return self.num_called == self._once | |||||
def state_dict(self) -> Dict: | |||||
r""" | |||||
通过该函数来保存该 `Filter` 的状态; | |||||
""" | |||||
return {"num_called": self.num_called, "num_executed": self.num_executed} | |||||
def load_state_dict(self, state: Dict): | |||||
r""" | |||||
通过该函数来加载 `Filter` 的状态; | |||||
:param state: 通过 `Filter.state_dict` 函数保存的状态元组; | |||||
""" | |||||
self.num_called = state["num_called"] | |||||
self.num_executed = state["num_executed"] | |||||
@@ -6,7 +6,7 @@ __all__ = [ | |||||
'CallbackManager' | 'CallbackManager' | ||||
] | ] | ||||
from .callback_events import Events | |||||
from .callback_event import Event | |||||
from .callback import Callback | from .callback import Callback | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .progress_callback import ProgressCallback, choose_progress_callback | from .progress_callback import ProgressCallback, choose_progress_callback | ||||
@@ -110,7 +110,7 @@ class CallbackManager: | |||||
def initialize_class_callbacks(self): | def initialize_class_callbacks(self): | ||||
r""" | r""" | ||||
在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 | 在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 | ||||
一个个 callback 时机,也就是 `Events` 的类别; | |||||
一个个 callback 时机,也就是 `Event` 的类别; | |||||
如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; | 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; | ||||
:param callbacks: | :param callbacks: | ||||
@@ -127,7 +127,7 @@ class CallbackManager: | |||||
:param callback: 一个具体的 callback 实例; | :param callback: 一个具体的 callback 实例; | ||||
""" | """ | ||||
self.all_callbacks.append(callback) | self.all_callbacks.append(callback) | ||||
for name, member in Events.__members__.items(): | |||||
for name, member in Event.__members__.items(): | |||||
_fn = getattr(callback, member.value) | _fn = getattr(callback, member.value) | ||||
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): | if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): | ||||
self.callback_fns[member.value].append(_fn) | self.callback_fns[member.value].append(_fn) | ||||
@@ -1,4 +1,20 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'Collator' | |||||
'Collator', | |||||
'NumpyNumberPadder', | |||||
'NumpySequencePadder', | |||||
"NumpyTensorPadder", | |||||
"Padder", | |||||
"NullPadder", | |||||
"RawNumberPadder", | |||||
"RawSequencePadder", | |||||
'TorchNumberPadder', | |||||
'TorchSequencePadder', | |||||
'TorchTensorPadder', | |||||
"PaddleNumberPadder", | |||||
"PaddleTensorPadder", | |||||
"PaddleSequencePadder", | |||||
"get_padded_numpy_array", | |||||
] | ] | ||||
from .collator import Collator | from .collator import Collator | ||||
from .padders import * |
@@ -0,0 +1,30 @@ | |||||
__all__ = [ | |||||
'NumpyNumberPadder', | |||||
'NumpySequencePadder', | |||||
"NumpyTensorPadder", | |||||
"Padder", | |||||
"NullPadder", | |||||
"RawNumberPadder", | |||||
"RawSequencePadder", | |||||
'TorchNumberPadder', | |||||
'TorchSequencePadder', | |||||
'TorchTensorPadder', | |||||
"PaddleNumberPadder", | |||||
"PaddleTensorPadder", | |||||
"PaddleSequencePadder", | |||||
"get_padded_numpy_array", | |||||
] | |||||
from .numpy_padder import * | |||||
from .padder import Padder, NullPadder | |||||
from .raw_padder import * | |||||
from .torch_padder import * | |||||
from .paddle_padder import * | |||||
from .utils import get_padded_numpy_array |
@@ -1,8 +1,3 @@ | |||||
from typing import Dict | |||||
from typing import Sequence, Any, Union, Dict | from typing import Sequence, Any, Union, Dict | ||||
from abc import ABC | from abc import ABC | ||||
@@ -93,6 +88,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
else: | |||||
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") | |||||
if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | ||||
if backend == 'raw': | if backend == 'raw': | ||||
@@ -103,6 +100,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
else: | |||||
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") | |||||
if depth == 1 and shape_len != 0: | if depth == 1 and shape_len != 0: | ||||
if backend == 'numpy': | if backend == 'numpy': | ||||
@@ -111,6 +110,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
else: | |||||
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") | |||||
if shape_len != 0 and depth>1: | if shape_len != 0 and depth>1: | ||||
msg = "Does not support pad tensor under nested list. If you need this, please report." | msg = "Does not support pad tensor under nested list. If you need this, please report." | ||||
@@ -179,23 +180,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: | |||||
else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 | else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 | ||||
catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 | catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 | ||||
return catalog | return catalog | ||||
""" | |||||
from numbers import Number | |||||
issubclass(type(3), Number) # True | |||||
issubclass(type(3.1), Number) # True | |||||
issubclass(type('3'), Number) # False | |||||
issubclass(type(True), Number) # True | |||||
issubclass(type(np.zeros(3)[0]), Number) # True | |||||
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True | |||||
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True | |||||
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 | |||||
is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) | |||||
""" | |||||
@@ -1,4 +1,7 @@ | |||||
__all__ = [ | |||||
"RawNumberPadder", | |||||
"RawSequencePadder" | |||||
] | |||||
from .padder import Padder | from .padder import Padder | ||||
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number | from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number | ||||
@@ -1,4 +1,8 @@ | |||||
__all__ = [ | |||||
'TorchNumberPadder', | |||||
'TorchSequencePadder', | |||||
'TorchTensorPadder' | |||||
] | |||||
from inspect import isclass | from inspect import isclass | ||||
import numpy as np | import numpy as np | ||||
@@ -1,6 +1,10 @@ | |||||
__all__ = [ | |||||
'get_padded_numpy_array' | |||||
] | |||||
from typing import Sequence, List | from typing import Sequence, List | ||||
from numbers import Number | |||||
import re | import re | ||||
from inspect import isclass | from inspect import isclass | ||||
@@ -2,8 +2,6 @@ __all__ = [ | |||||
'Loop', | 'Loop', | ||||
'EvaluateBatchLoop', | 'EvaluateBatchLoop', | ||||
'TrainBatchLoop', | 'TrainBatchLoop', | ||||
'State', | |||||
'TrainerState', | |||||
'Evaluator', | 'Evaluator', | ||||
'Trainer', | 'Trainer', | ||||
] | ] | ||||
@@ -17,10 +17,10 @@ from .utils import State, TrainerState | |||||
from .utils.utils import check_evaluate_every | from .utils.utils import check_evaluate_every | ||||
from .evaluator import Evaluator | from .evaluator import Evaluator | ||||
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader | ||||
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList | |||||
from fastNLP.core.callbacks import Callback, CallbackManager | |||||
from fastNLP.core.callbacks.callback import _CallbackWrapper | from fastNLP.core.callbacks.callback import _CallbackWrapper | ||||
from fastNLP.core.callbacks.callback_manager import prepare_callbacks | from fastNLP.core.callbacks.callback_manager import prepare_callbacks | ||||
from fastNLP.core.callbacks.callback_events import _SingleEventState | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
from fastNLP.core.drivers import Driver | from fastNLP.core.drivers import Driver | ||||
from fastNLP.core.drivers.utils import choose_driver | from fastNLP.core.drivers.utils import choose_driver | ||||
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext | ||||
@@ -398,7 +398,7 @@ class Trainer(TrainerEventTrigger): | |||||
if self.cur_epoch_idx % evaluate_every == 0: | if self.cur_epoch_idx % evaluate_every == 0: | ||||
self.run_evaluate() | self.run_evaluate() | ||||
def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): | |||||
def add_callback_fn(self, event: Event, fn: Callable): | |||||
r""" | r""" | ||||
在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; | 在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; | ||||
这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; | 这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; | ||||
@@ -406,19 +406,69 @@ class Trainer(TrainerEventTrigger): | |||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | ||||
:param fn: 具体的 callback 函数; | :param fn: 具体的 callback 函数; | ||||
""" | """ | ||||
if not isinstance(event, (_SingleEventState, EventsList)): | |||||
raise ValueError("parameter event should only be `Events` or `EventsList` type.") | |||||
if not isinstance(event, Event): | |||||
raise ValueError("parameter event should only be `Event` type.") | |||||
_custom_callback = _CallbackWrapper(event, fn) | _custom_callback = _CallbackWrapper(event, fn) | ||||
self.callback_manager.dissect_one_callback(_custom_callback) | self.callback_manager.dissect_one_callback(_custom_callback) | ||||
@classmethod | @classmethod | ||||
def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None): | |||||
def on(cls, event: Event, marker: Optional[str] = None): | |||||
r""" | r""" | ||||
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | ||||
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如 | |||||
Trainer.__init__(): | |||||
on_after_trainer_initialized(trainer, driver) | |||||
Trainer.run(): | |||||
if num_eval_sanity_batch>0: | |||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||||
on_sanity_check_end(trainer, sanity_check_res) | |||||
try: | |||||
on_train_begin(trainer) | |||||
while cur_epoch_idx < n_epochs: | |||||
on_train_epoch_begin(trainer) | |||||
while batch_idx_in_epoch<=num_batches_per_epoch: | |||||
on_fetch_data_begin(trainer) | |||||
batch = next(dataloader) | |||||
on_fetch_data_end(trainer) | |||||
on_train_batch_begin(trainer, batch, indices) | |||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 | |||||
on_after_backward(trainer) | |||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_train_batch_end(trainer) | |||||
on_train_epoch_end(trainer) | |||||
except BaseException: | |||||
self.on_exception(trainer, exception) | |||||
finally: | |||||
on_train_end(trainer) | |||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ | |||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 | |||||
特定的时间调用。 | |||||
Example:: | |||||
from fastNLP import Event | |||||
@Trainer.on(Event.on_save_model()) | |||||
def do_something_1(trainer): | |||||
# do something | |||||
# 以上函数会在 Trainer 保存模型时执行。 | |||||
@Trainer.on(Event.on_save_model(once=True)) | |||||
def do_something_2(trainer): | |||||
# do something | |||||
# 以上函数会在 Trainer 保存模型时执行,但只执行一次。 | |||||
@Trainer.on(Event.on_train_batch_begin(every=2)) | |||||
def do_something_3(trainer, batch, indices): | |||||
# do something | |||||
# 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。 | |||||
注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; | 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; | ||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | |||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含 | |||||
特定的参数,可以通过上述说明查阅。 | |||||
:param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, | :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, | ||||
表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer | 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer | ||||
实例使用; | 实例使用; | ||||
@@ -426,9 +476,9 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
def wrapper(fn: Callable) -> Callable: | def wrapper(fn: Callable) -> Callable: | ||||
cls._custom_callbacks[marker].append((event, fn)) | |||||
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | ||||
_check_valid_parameters_number(fn, callback_fn_args) | _check_valid_parameters_number(fn, callback_fn_args) | ||||
cls._custom_callbacks[marker].append((event, fn)) | |||||
return fn | return fn | ||||
return wrapper | return wrapper | ||||
@@ -770,15 +770,6 @@ class DataSet: | |||||
df = self.to_pandas() | df = self.to_pandas() | ||||
return df.to_csv(path, encoding="utf-8") | return df.to_csv(path, encoding="utf-8") | ||||
def set_ignore(self, *field_names) -> None: | |||||
""" | |||||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||||
:param field_names: | |||||
:return: | |||||
""" | |||||
self.collator.set_ignore(*field_names) | |||||
@property | @property | ||||
def collator(self) -> Collator: | def collator(self) -> Collator: | ||||
if self._collator is None: | if self._collator is None: | ||||
@@ -55,7 +55,6 @@ class ReproducibleBatchSampler: | |||||
class ReproduceBatchSampler(ReproducibleBatchSampler): | class ReproduceBatchSampler(ReproducibleBatchSampler): | ||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | """ | ||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | 可以使得 batch_sampler 对象状态恢复的 wrapper 。 | ||||
@@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet | |||||
class ReproducibleSampler: | class ReproducibleSampler: | ||||
""" | """ | ||||
可复现的 Sampler 对象。 | |||||
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | ||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | ||||
@@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]): | |||||
else: | else: | ||||
return data.cuda(get_paddle_device_id(device)) | return data.cuda(get_paddle_device_id(device)) | ||||
def get_paddle_gpu_str(device: Union[str, int]): | def get_paddle_gpu_str(device: Union[str, int]): | ||||
""" | """ | ||||
获得 `gpu:x` 类型的设备名 | 获得 `gpu:x` 类型的设备名 | ||||
@@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]): | |||||
return device.replace("cuda", "gpu") | return device.replace("cuda", "gpu") | ||||
return f"gpu:{device}" | return f"gpu:{device}" | ||||
def get_paddle_device_id(device: Union[str, int]): | def get_paddle_device_id(device: Union[str, int]): | ||||
""" | """ | ||||
获得 gpu 的设备id | 获得 gpu 的设备id | ||||
@@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | |||||
return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | ||||
def is_in_paddle_dist(): | def is_in_paddle_dist(): | ||||
""" | """ | ||||
判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 | 判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 | ||||
""" | """ | ||||
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | ||||
def is_in_fnlp_paddle_dist(): | def is_in_fnlp_paddle_dist(): | ||||
""" | """ | ||||
判断是否处于 FastNLP 拉起的分布式进程中 | 判断是否处于 FastNLP 拉起的分布式进程中 | ||||
""" | """ | ||||
return FASTNLP_DISTRIBUTED_CHECK in os.environ | return FASTNLP_DISTRIBUTED_CHECK in os.environ | ||||
def is_in_paddle_launch_dist(): | def is_in_paddle_launch_dist(): | ||||
""" | """ | ||||
判断是否处于 launch 启动的分布式进程中 | 判断是否处于 launch 启动的分布式进程中 | ||||
@@ -332,13 +332,44 @@ class DataBundle: | |||||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | show_progress_bar=show_progress_bar, progress_desc=progress_desc) | ||||
return res | return res | ||||
def set_pad_val(self, *field_names, val=0) -> None: | |||||
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | |||||
""" | |||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: self | |||||
""" | |||||
for _, ds in self.iter_datasets(): | for _, ds in self.iter_datasets(): | ||||
ds.set_pad_val(*field_names, val=val) | |||||
ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend, | |||||
pad_fn=pad_fn) | |||||
return self | |||||
def set_input(self, *field_names) -> None: | |||||
def set_ignore(self, *field_names) -> "DataBundle": | |||||
""" | |||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||||
Ex:: | |||||
collator.set_ignore('field1', 'field2') | |||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||||
:return: self | |||||
""" | |||||
for _, ds in self.iter_datasets(): | for _, ds in self.iter_datasets(): | ||||
ds.set_input(*field_names) | |||||
ds.collator.set_ignore(*field_names) | |||||
return self | |||||
def __repr__(self) -> str: | def __repr__(self) -> str: | ||||
_str = '' | _str = '' | ||||
@@ -1,11 +1,11 @@ | |||||
import pytest | import pytest | ||||
from functools import reduce | from functools import reduce | ||||
from fastNLP.core.callbacks.callback_events import Events, Filter | |||||
from fastNLP.core.callbacks.callback_event import Event, Filter | |||||
class TestFilter: | |||||
class TestFilter: | |||||
def test_params_check(self): | def test_params_check(self): | ||||
# 顺利通过 | # 顺利通过 | ||||
_filter1 = Filter(every=10) | _filter1 = Filter(every=10) | ||||
@@ -80,35 +80,6 @@ class TestFilter: | |||||
_res.append(cu_res) | _res.append(cu_res) | ||||
assert _res == [9] | assert _res == [9] | ||||
def test_filter_fn(self): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
def test_extract_filter_from_fn(self): | def test_extract_filter_from_fn(self): | ||||
@Filter(every=10) | @Filter(every=10) | ||||
@@ -155,3 +126,119 @@ class TestFilter: | |||||
assert _res == [w - 1 for w in range(60, 101, 10)] | assert _res == [w - 1 for w in range(60, 101, 10)] | ||||
@pytest.mark.torch | |||||
def test_filter_fn_torch(): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
class TestCallbackEvents: | |||||
def test_every(self): | |||||
# 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; | |||||
event_state = Events.on_train_begin() # 什么都不输入是应当默认 every=1; | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == list(range(100)) | |||||
event_state = Events.on_train_begin(every=10) | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(10, 101, 10)] | |||||
def test_once(self): | |||||
event_state = Events.on_train_begin(once=10) | |||||
@Filter(once=event_state.once) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [9] | |||||
@pytest.mark.torch | |||||
def test_callback_events_torch(): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
event_state = Events.on_train_begin(filter_fn=filter_fn) | |||||
@Filter(filter_fn=event_state.filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
@@ -218,9 +218,9 @@ def test_model_checkpoint_callback_2( | |||||
path = Path.cwd().joinpath("test_model_checkpoint") | path = Path.cwd().joinpath("test_model_checkpoint") | ||||
path.mkdir(exist_ok=True, parents=True) | path.mkdir(exist_ok=True, parents=True) | ||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | ||||
@@ -65,10 +65,9 @@ def model_and_optimizers(): | |||||
return trainer_params | return trainer_params | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger_1( | def test_trainer_event_trigger_1( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -100,12 +99,13 @@ def test_trainer_event_trigger_1( | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
for name, member in Events.__members__.items(): | |||||
assert member.value in output[0] | |||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger_2( | def test_trainer_event_trigger_2( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -114,86 +114,86 @@ def test_trainer_event_trigger_2( | |||||
n_epochs=2, | n_epochs=2, | ||||
): | ): | ||||
@Trainer.on(Events.on_after_trainer_initialized()) | |||||
@Trainer.on(Event.on_after_trainer_initialized()) | |||||
def on_after_trainer_initialized(trainer, driver): | def on_after_trainer_initialized(trainer, driver): | ||||
print("on_after_trainer_initialized") | print("on_after_trainer_initialized") | ||||
@Trainer.on(Events.on_sanity_check_begin()) | |||||
@Trainer.on(Event.on_sanity_check_begin()) | |||||
def on_sanity_check_begin(trainer): | def on_sanity_check_begin(trainer): | ||||
print("on_sanity_check_begin") | print("on_sanity_check_begin") | ||||
@Trainer.on(Events.on_sanity_check_end()) | |||||
@Trainer.on(Event.on_sanity_check_end()) | |||||
def on_sanity_check_end(trainer, sanity_check_res): | def on_sanity_check_end(trainer, sanity_check_res): | ||||
print("on_sanity_check_end") | print("on_sanity_check_end") | ||||
@Trainer.on(Events.on_train_begin()) | |||||
@Trainer.on(Event.on_train_begin()) | |||||
def on_train_begin(trainer): | def on_train_begin(trainer): | ||||
print("on_train_begin") | print("on_train_begin") | ||||
@Trainer.on(Events.on_train_end()) | |||||
@Trainer.on(Event.on_train_end()) | |||||
def on_train_end(trainer): | def on_train_end(trainer): | ||||
print("on_train_end") | print("on_train_end") | ||||
@Trainer.on(Events.on_train_epoch_begin()) | |||||
@Trainer.on(Event.on_train_epoch_begin()) | |||||
def on_train_epoch_begin(trainer): | def on_train_epoch_begin(trainer): | ||||
if trainer.cur_epoch_idx >= 1: | if trainer.cur_epoch_idx >= 1: | ||||
# 触发 on_exception; | # 触发 on_exception; | ||||
raise Exception | raise Exception | ||||
print("on_train_epoch_begin") | print("on_train_epoch_begin") | ||||
@Trainer.on(Events.on_train_epoch_end()) | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
def on_train_epoch_end(trainer): | def on_train_epoch_end(trainer): | ||||
print("on_train_epoch_end") | print("on_train_epoch_end") | ||||
@Trainer.on(Events.on_fetch_data_begin()) | |||||
@Trainer.on(Event.on_fetch_data_begin()) | |||||
def on_fetch_data_begin(trainer): | def on_fetch_data_begin(trainer): | ||||
print("on_fetch_data_begin") | print("on_fetch_data_begin") | ||||
@Trainer.on(Events.on_fetch_data_end()) | |||||
@Trainer.on(Event.on_fetch_data_end()) | |||||
def on_fetch_data_end(trainer): | def on_fetch_data_end(trainer): | ||||
print("on_fetch_data_end") | print("on_fetch_data_end") | ||||
@Trainer.on(Events.on_train_batch_begin()) | |||||
@Trainer.on(Event.on_train_batch_begin()) | |||||
def on_train_batch_begin(trainer, batch, indices=None): | def on_train_batch_begin(trainer, batch, indices=None): | ||||
print("on_train_batch_begin") | print("on_train_batch_begin") | ||||
@Trainer.on(Events.on_train_batch_end()) | |||||
@Trainer.on(Event.on_train_batch_end()) | |||||
def on_train_batch_end(trainer): | def on_train_batch_end(trainer): | ||||
print("on_train_batch_end") | print("on_train_batch_end") | ||||
@Trainer.on(Events.on_exception()) | |||||
@Trainer.on(Event.on_exception()) | |||||
def on_exception(trainer, exception): | def on_exception(trainer, exception): | ||||
print("on_exception") | print("on_exception") | ||||
@Trainer.on(Events.on_before_backward()) | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def on_before_backward(trainer, outputs): | def on_before_backward(trainer, outputs): | ||||
print("on_before_backward") | print("on_before_backward") | ||||
@Trainer.on(Events.on_after_backward()) | |||||
@Trainer.on(Event.on_after_backward()) | |||||
def on_after_backward(trainer): | def on_after_backward(trainer): | ||||
print("on_after_backward") | print("on_after_backward") | ||||
@Trainer.on(Events.on_before_optimizers_step()) | |||||
@Trainer.on(Event.on_before_optimizers_step()) | |||||
def on_before_optimizers_step(trainer, optimizers): | def on_before_optimizers_step(trainer, optimizers): | ||||
print("on_before_optimizers_step") | print("on_before_optimizers_step") | ||||
@Trainer.on(Events.on_after_optimizers_step()) | |||||
@Trainer.on(Event.on_after_optimizers_step()) | |||||
def on_after_optimizers_step(trainer, optimizers): | def on_after_optimizers_step(trainer, optimizers): | ||||
print("on_after_optimizers_step") | print("on_after_optimizers_step") | ||||
@Trainer.on(Events.on_before_zero_grad()) | |||||
@Trainer.on(Event.on_before_zero_grad()) | |||||
def on_before_zero_grad(trainer, optimizers): | def on_before_zero_grad(trainer, optimizers): | ||||
print("on_before_zero_grad") | print("on_before_zero_grad") | ||||
@Trainer.on(Events.on_after_zero_grad()) | |||||
@Trainer.on(Event.on_after_zero_grad()) | |||||
def on_after_zero_grad(trainer, optimizers): | def on_after_zero_grad(trainer, optimizers): | ||||
print("on_after_zero_grad") | print("on_after_zero_grad") | ||||
@Trainer.on(Events.on_evaluate_begin()) | |||||
@Trainer.on(Event.on_evaluate_begin()) | |||||
def on_evaluate_begin(trainer): | def on_evaluate_begin(trainer): | ||||
print("on_evaluate_begin") | print("on_evaluate_begin") | ||||
@Trainer.on(Events.on_evaluate_end()) | |||||
@Trainer.on(Event.on_evaluate_end()) | |||||
def on_evaluate_end(trainer, results): | def on_evaluate_end(trainer, results): | ||||
print("on_evaluate_end") | print("on_evaluate_end") | ||||
@@ -218,13 +218,13 @@ def test_trainer_event_trigger_2( | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
for name, member in Events.__members__.items(): | |||||
assert member.value in output[0] | |||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger_3( | def test_trainer_event_trigger_3( | ||||
@@ -233,113 +233,74 @@ def test_trainer_event_trigger_3( | |||||
device, | device, | ||||
n_epochs=2, | n_epochs=2, | ||||
): | ): | ||||
import re | |||||
@Trainer.on(Events.on_after_trainer_initialized) | |||||
def on_after_trainer_initialized(trainer, driver): | |||||
print("on_after_trainer_initialized") | |||||
once_message_1 = "This message should be typed 1 times." | |||||
once_message_2 = "test_filter_fn" | |||||
once_message_3 = "once message 3" | |||||
twice_message = "twice message hei hei" | |||||
@Trainer.on(Events.on_sanity_check_begin) | |||||
def on_sanity_check_begin(trainer): | |||||
print("on_sanity_check_begin") | |||||
@Trainer.on(Event.on_train_epoch_begin(every=2)) | |||||
def train_epoch_begin_1(trainer): | |||||
print(once_message_1) | |||||
@Trainer.on(Events.on_sanity_check_end) | |||||
def on_sanity_check_end(trainer, sanity_check_res): | |||||
print("on_sanity_check_end") | |||||
@Trainer.on(Event.on_train_epoch_begin()) | |||||
def train_epoch_begin_2(trainer): | |||||
print(twice_message) | |||||
@Trainer.on(Events.on_train_begin) | |||||
def on_train_begin(trainer): | |||||
print("on_train_begin") | |||||
@Trainer.on(Event.on_train_epoch_begin(once=2)) | |||||
def train_epoch_begin_3(trainer): | |||||
print(once_message_3) | |||||
@Trainer.on(Events.on_train_end) | |||||
def on_train_end(trainer): | |||||
print("on_train_end") | |||||
def filter_fn(filter, trainer): | |||||
if trainer.cur_epoch_idx == 1: | |||||
return True | |||||
else: | |||||
return False | |||||
@Trainer.on(Events.on_train_epoch_begin) | |||||
def on_train_epoch_begin(trainer): | |||||
if trainer.cur_epoch_idx >= 1: | |||||
# 触发 on_exception; | |||||
raise Exception | |||||
print("on_train_epoch_begin") | |||||
@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn)) | |||||
def test_filter_fn(trainer): | |||||
print(once_message_2) | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def on_train_epoch_end(trainer): | |||||
print("on_train_epoch_end") | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
@Trainer.on(Events.on_fetch_data_begin) | |||||
def on_fetch_data_begin(trainer): | |||||
print("on_fetch_data_begin") | |||||
n_epochs=n_epochs, | |||||
) | |||||
@Trainer.on(Events.on_fetch_data_end) | |||||
def on_fetch_data_end(trainer): | |||||
print("on_fetch_data_end") | |||||
trainer.run() | |||||
@Trainer.on(Events.on_train_batch_begin) | |||||
def on_train_batch_begin(trainer, batch, indices=None): | |||||
print("on_train_batch_begin") | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@Trainer.on(Events.on_train_batch_end) | |||||
def on_train_batch_end(trainer): | |||||
print("on_train_batch_end") | |||||
@Trainer.on(Events.on_exception) | |||||
def on_exception(trainer, exception): | |||||
print("on_exception") | |||||
once_pattern_1 = re.compile(once_message_1) | |||||
once_pattern_2 = re.compile(once_message_2) | |||||
once_pattern_3 = re.compile(once_message_3) | |||||
twice_pattern = re.compile(twice_message) | |||||
@Trainer.on(Events.on_before_backward) | |||||
def on_before_backward(trainer, outputs): | |||||
print("on_before_backward") | |||||
once_res_1 = once_pattern_1.findall(output[0]) | |||||
assert len(once_res_1) == 1 | |||||
once_res_2 = once_pattern_2.findall(output[0]) | |||||
assert len(once_res_2) == 1 | |||||
once_res_3 = once_pattern_3.findall(output[0]) | |||||
assert len(once_res_3) == 1 | |||||
twice_res = twice_pattern.findall(output[0]) | |||||
assert len(twice_res) == 2 | |||||
@Trainer.on(Events.on_after_backward) | |||||
def on_after_backward(trainer): | |||||
print("on_after_backward") | |||||
@Trainer.on(Events.on_before_optimizers_step) | |||||
def on_before_optimizers_step(trainer, optimizers): | |||||
print("on_before_optimizers_step") | |||||
@Trainer.on(Events.on_after_optimizers_step) | |||||
def on_after_optimizers_step(trainer, optimizers): | |||||
print("on_after_optimizers_step") | |||||
@Trainer.on(Events.on_before_zero_grad) | |||||
def on_before_zero_grad(trainer, optimizers): | |||||
print("on_before_zero_grad") | |||||
@Trainer.on(Events.on_after_zero_grad) | |||||
def on_after_zero_grad(trainer, optimizers): | |||||
print("on_after_zero_grad") | |||||
@Trainer.on(Events.on_evaluate_begin) | |||||
def on_evaluate_begin(trainer): | |||||
print("on_evaluate_begin") | |||||
@Trainer.on(Events.on_evaluate_end) | |||||
def on_evaluate_end(trainer, results): | |||||
print("on_evaluate_end") | |||||
with pytest.raises(Exception): | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
for name, member in Events.__members__.items(): | |||||
assert member.value in output[0] | |||||
@@ -260,9 +260,9 @@ def test_trainer_on_exception( | |||||
cur_rank, | cur_rank, | ||||
n_epochs=2, | n_epochs=2, | ||||
): | ): | ||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == cur_rank: | if trainer.driver.get_local_rank() == cur_rank: | ||||
raise NotImplementedError | raise NotImplementedError | ||||