@@ -1,4 +1,53 @@
 __all__ = [
+    # callbacks
+    'Callback',
+    'Event',
+    'Filter',
+    'CallbackManager',
+    'CheckpointCallback',
+    'choose_progress_callback',
+    'ProgressCallback',
+    'RichCallback',
+    "LRSchedCallback",
+    'LoadBestModelCallback',
+    "EarlyStopCallback",
+    'MoreEvaluateCallback',
+    "TorchWarmupCallback",
+    "TorchGradClipCallback",
+
+    # collators
+    'Collator',
+    'NumpyNumberPadder',
+    'NumpySequencePadder',
+    "NumpyTensorPadder",
+    "Padder",
+    "NullPadder",
+    "RawNumberPadder",
+    "RawSequencePadder",
+    'TorchNumberPadder',
+    'TorchSequencePadder',
+    'TorchTensorPadder',
+    "PaddleNumberPadder",
+    "PaddleTensorPadder",
+    "PaddleSequencePadder",
+    "get_padded_numpy_array",
+
+    # controllers
+    'Loop',
+    'EvaluateBatchLoop',
+    'TrainBatchLoop',
+    'Evaluator',
+    'Trainer',
+
+    # dataloaders TODO: mix_dataloader still needs to be sorted out
+
+    # dataset
+    'DataSet',
+    'FieldArray',
+    'Instance',
+    'ApplyResultException',
+
+    # drivers
     "TorchSingleDriver",
     "TorchDDPDriver",
     "PaddleSingleDriver",
@@ -7,16 +56,16 @@ __all__ = [
     "JittorMPIDriver",
     "TorchPaddleDriver",
-    "paddle_to",
-    "get_paddle_gpu_str",
-    "get_paddle_device_id",
-    "paddle_move_data_to_device",
-    "torch_paddle_move_data_to_device",
-]
-# TODO: the imports here should be tidied up later; each sub module should first import its own
-# classes and functions, and the outer module should then import directly from the submodules;
-from fastNLP.core.controllers.trainer import Trainer
-from fastNLP.core.controllers.evaluator import Evaluator
-from fastNLP.core.dataloaders.torch_dataloader import *
+
+    # log
+    "logger"
+    #
+]
+from .callbacks import *
+from .collators import *
+from .controllers import *
+from .dataloaders import *
+from .dataset import *
 from .drivers import *
+from .log import *
 from .utils import *
@@ -1,7 +1,6 @@
 __all__ = [
     'Callback',
-    'Events',
-    'EventsList',
+    'Event',
     'Filter',
     'CallbackManager',
     'CheckpointCallback',
@@ -20,7 +19,7 @@ __all__ = [
 from .callback import Callback
-from .callback_events import EventsList, Events, Filter
+from .callback_event import Event, Filter
 from .callback_manager import CallbackManager
 from .checkpoint_callback import CheckpointCallback
 from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
@@ -3,10 +3,9 @@ __all__ = [
     'Callback',
 ]
-from typing import Union, Callable, Dict, Optional, Any
+from typing import Callable, Dict, Optional
-from .callback_events import Events, EventsList, Filter
-from fastNLP.core.callbacks.callback_events import _SingleEventState
+from .callback_event import Event, Filter
 
 class Callback:
@@ -14,32 +13,35 @@ class Callback:
     Every callback class that is actually used, whether one of the callback classes fastNLP provides by default or
     one the user defines, should inherit from this base class;
     the callback hooks are invoked roughly in the following order:
     Trainer.__init__():
-        on_after_trainer_initialized()
+        on_after_trainer_initialized(trainer, driver)
     Trainer.run():
         if num_eval_sanity_batch>0:
-            on_sanity_check_begin()  # only if num_eval_sanity_batch is set
-            on_sanity_check_end()
+            on_sanity_check_begin(trainer)  # only if num_eval_sanity_batch is set
+            on_sanity_check_end(trainer, sanity_check_res)
         try:
-            on_train_begin()
+            on_train_begin(trainer)
             while cur_epoch_idx < n_epochs:
-                on_train_epoch_begin()
+                on_train_epoch_begin(trainer)
                 while batch_idx_in_epoch<=num_batches_per_epoch:
-                    on_fetch_data_begin()
-                    on_fetch_data_end()
-                    on_train_batch_begin()
-                    on_before_backward()
-                    on_after_backward()
-                    on_before_zero_grad()  # actual invocation depends on accumulation_steps
-                    on_after_zero_grad()  # actual invocation depends on accumulation_steps
-                    on_before_optimizers_step()  # actual invocation depends on accumulation_steps
-                    on_after_optimizers_step()  # actual invocation depends on accumulation_steps
-                    on_train_batch_end()
-                on_train_epoch_end()
+                    on_fetch_data_begin(trainer)
+                    batch = next(dataloader)
+                    on_fetch_data_end(trainer)
+                    on_train_batch_begin(trainer, batch, indices)
+                    on_before_backward(trainer, outputs)  # outputs has passed through output_mapping (if set); otherwise it is the raw model output.
+                    on_after_backward(trainer)
+                    on_before_zero_grad(trainer, optimizers)  # actual invocation depends on accumulation_steps
+                    on_after_zero_grad(trainer, optimizers)  # actual invocation depends on accumulation_steps
+                    on_before_optimizers_step(trainer, optimizers)  # actual invocation depends on accumulation_steps
+                    on_after_optimizers_step(trainer, optimizers)  # actual invocation depends on accumulation_steps
+                    on_train_batch_end(trainer)
+                on_train_epoch_end(trainer)
         except BaseException:
-            self.on_exception()
+            self.on_exception(trainer, exception)
         finally:
-            on_train_end()
-    Other callbacks such as on_evaluate_begin()/on_evaluate_end() are
+            on_train_end(trainer)
+    Other callbacks such as on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
+    on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer) are invoked at the appropriate
+    points inside Trainer.run() as needed.
     """
 
     def on_after_trainer_initialized(self, trainer, driver):
@@ -294,18 +296,14 @@ class _CallbackWrapper(Callback):
     For a callback function the user adds through the function decorator, this _CallbackWrapper class wraps it;
     the class keeps only that single user callback function;
     """
-    def __init__(self, event: Union[Events, EventsList], fn: Callable):
+    def __init__(self, event: Event, fn: Callable):
         r"""
-        :param event: the specific callback hook, e.g. 'on_train_begin'; multiple hooks may be given, in which case
-            `event` should be of type 'EventsList';
+        :param event: the specific callback hook, e.g. 'on_train_begin';
         :param fn: the user-defined callback function;
         """
         self.fn = fn
-        if isinstance(event, EventsList):
-            for each_event in event:
-                _filter = Filter(each_event.every, each_event.once, each_event.filter_fn)
-                setattr(self, each_event.value, _filter(fn))
-        elif isinstance(event, _SingleEventState):
+        if isinstance(event, Event):
             _filter = Filter(event.every, event.once, event.filter_fn)
             setattr(self, event.value, _filter(fn))
@@ -0,0 +1,499 @@ | |||||
from typing import Optional, Callable, Dict | |||||
from functools import wraps | |||||
__all__ = [ | |||||
'Event', | |||||
'Filter' | |||||
] | |||||
def check_legality(fn): | |||||
@wraps(fn) | |||||
def wrap(every=None, once=None, filter_fn=None): | |||||
if (every is None) and (once is None) and (filter_fn is None): | |||||
every = 1 | |||||
if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): | |||||
raise ValueError("These three values should be only set one.") | |||||
if (filter_fn is not None) and not callable(filter_fn): | |||||
raise TypeError("Argument filter_fn should be a callable") | |||||
if (every is not None) and not (isinstance(every, int) and every > 0): | |||||
raise ValueError("Argument every should be integer and greater than zero") | |||||
if (once is not None) and not (isinstance(once, int) and once > 0): | |||||
raise ValueError("Argument once should be integer and positive") | |||||
return fn(every=every, once=once, filter_fn=filter_fn) | |||||
return wrap | |||||
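
A quick sketch of the validation this decorator enforces (hypothetical factory, shown only for illustration):

    @check_legality
    def my_event_factory(every=None, once=None, filter_fn=None):
        return (every, once, filter_fn)

    my_event_factory()                    # -> (1, None, None): defaults to every=1
    my_event_factory(every=10)            # -> (10, None, None)
    # my_event_factory(every=2, once=3)   # raises ValueError: only one may be set
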
+
+
+class Event:
+    every: Optional[int]
+    once: Optional[int]
+
+    def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None,
+                 filter_fn: Optional[Callable] = None):
+        """
+        Do not instantiate this object directly; create it through calls such as
+        Event.on_after_trainer_initialized() instead.
+
+        :param value: the Trainer callback hook.
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer) as input; the filter object exposes filter.num_called and
+            filter.num_executed (how many times the hook was triggered vs. how many times it actually ran); trainer
+            is the Trainer currently running.
+        """
+        self.every = every
+        self.once = once
+        self.filter_fn = filter_fn
+        self.value = value
+
+    def __str__(self):
+        return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once,
+                                                                           self.filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_after_trainer_initialized(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_after_trainer_initialized.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_sanity_check_begin(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_sanity_check_begin.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_sanity_check_end(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_sanity_check_end.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_train_begin(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_train_begin.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_train_end(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_train_end.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_train_epoch_begin(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_train_epoch_begin.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_train_epoch_end(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_train_epoch_end.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_fetch_data_begin(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_fetch_data_begin.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_fetch_data_end(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_fetch_data_end.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_train_batch_begin(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_train_batch_begin.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_train_batch_end(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_train_batch_end.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_exception(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_exception.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_save_model(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_save_model.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_load_model(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_load_model.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_save_checkpoint(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_save_checkpoint.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_load_checkpoint(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_load_checkpoint.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_before_backward(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_before_backward.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_after_backward(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_after_backward.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_before_optimizers_step(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_before_optimizers_step.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_after_optimizers_step(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_after_optimizers_step.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_before_zero_grad(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_before_zero_grad.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_after_zero_grad(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_after_zero_grad.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_evaluate_begin(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_evaluate_begin.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn)
+
+    @staticmethod
+    @check_legality
+    def on_evaluate_end(every=None, once=None, filter_fn=None):
+        """
+        Called when the Trainer reaches on_evaluate_end.
+        The following three parameters are mutually exclusive; at most one of them may be set. The default behaves
+        like every=1.
+
+        :param int every: run only once per `every` triggers.
+        :param bool once: whether to run only the first time and then never again.
+        :param Callable filter_fn: takes (filter, trainer); the filter object exposes filter.num_called and
+            filter.num_executed (times triggered vs. times actually run); trainer is the Trainer currently running.
+        :return:
+        """
+        return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn)
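
A sketch of how these factories are typically consumed (assuming `from fastNLP import Trainer` is available in this version; function names are illustrative):

    from fastNLP.core.callbacks import Event

    # run the hook only on every 100th trigger
    @Trainer.on(Event.on_train_batch_end(every=100))
    def log_batch(trainer):
        print(trainer.batch_idx_in_epoch)

    # custom schedule: let the hook through until it has executed 3 times
    @Trainer.on(Event.on_train_epoch_end(filter_fn=lambda flt, tr: flt.num_executed < 3))
    def first_three_epochs(trainer):
        ...
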
+
+
+class Filter:
+    def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None):
+        r"""
+        Use this `Filter` as a function decorator to control how often the decorated function actually runs;
+
+        :param every: run the function once per `every` calls;
+        :param once: run the function only once;
+        :param filter_fn: a user-defined frequency-control function; note that its decision logic should be stateless
+            apart from `self.num_called` and `self.num_executed`, because these two counters are reset after the
+            sanity run;
+        """
+        # check legality
+        check_legality(lambda *args, **kwargs: ...)(every, once, filter_fn)
+        if (every is None) and (once is None) and (filter_fn is None):
+            every = 1
+        # set up the counters, which act as shared state;
+        self.num_called = 0
+        self.num_executed = 0
+
+        if every is not None:
+            self._every = every
+            self._filter = self.every_filter
+        elif once is not None:
+            self._once = once
+            self._filter = self.once_filter
+        else:
+            self._filter = filter_fn
+
+    def __call__(self, fn: Callable):
+        @wraps(fn)
+        def wrapper(*args, **kwargs) -> Callable:
+            self.num_called += 1
+            # our callback functions have a fixed signature, and the first argument is guaranteed to be the trainer;
+            trainer = args[0]
+            if self._filter(self, trainer):
+                self.num_executed += 1
+                return fn(*args, **kwargs)
+
+        wrapper.__fastNLP_filter__ = self
+        return wrapper
+
+    def every_filter(self, *args):
+        return self.num_called % self._every == 0
+
+    def once_filter(self, *args):
+        return self.num_called == self._once
+
+    def state_dict(self) -> Dict:
+        r"""
+        Save the state of this `Filter`;
+        """
+        return {"num_called": self.num_called, "num_executed": self.num_executed}
+
+    def load_state_dict(self, state: Dict):
+        r"""
+        Load the state of this `Filter`;
+
+        :param state: the state dict saved by `Filter.state_dict`;
+        """
+        self.num_called = state["num_called"]
+        self.num_executed = state["num_executed"]
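
A standalone sketch of the decorator in action (illustrative; the decorated function must take the trainer as its first positional argument, here passed as None since `every_filter` ignores it):

    flt = Filter(every=3)

    @flt
    def hook(trainer):
        print("ran on call", flt.num_called)

    for _ in range(9):
        hook(None)   # prints on calls 3, 6 and 9

    flt.state_dict()   # {'num_called': 9, 'num_executed': 3}
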
@@ -1,206 +0,0 @@ | |||||
from enum import Enum, unique | |||||
from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict | |||||
from types import DynamicClassAttribute | |||||
from functools import wraps | |||||
__all__ = [ | |||||
'Events', | |||||
'EventsList', | |||||
'Filter' | |||||
] | |||||
class _SingleEventState: | |||||
every: Optional[int] | |||||
once: Optional[int] | |||||
def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, | |||||
filter_fn: Optional[Callable] = None, name: Optional[str] = None): | |||||
# 具体的检测参数对错的逻辑放在具体的 Filter 里; | |||||
if every is None and once is None and filter_fn is None: | |||||
self.every = 1 | |||||
self.once = None | |||||
self.filter_fn = None | |||||
else: | |||||
self.every = every | |||||
self.once = once | |||||
self.filter_fn = filter_fn | |||||
if not hasattr(self, "_value_"): | |||||
self._value_ = value | |||||
if not hasattr(self, "_name_") and name is not None: | |||||
self._name_ = name | |||||
# copied to be compatible to enum | |||||
@DynamicClassAttribute | |||||
def name(self) -> str: | |||||
"""The name of the Enum member.""" | |||||
return self._name_ | |||||
@DynamicClassAttribute | |||||
def value(self) -> str: | |||||
"""The value of the Enum member.""" | |||||
return self._value_ | |||||
def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): | |||||
return _SingleEventState(self.value, every, once, filter_fn, self.name) | |||||
def __str__(self): | |||||
return "<event={0}, every={1}, once={2}, filter fn is None:{3}>".format(self.name, self.every, self.once, | |||||
self.filter_fn) | |||||
def __eq__(self, other) -> bool: | |||||
if isinstance(other, _SingleEventState): | |||||
return self.name == other.name | |||||
elif isinstance(other, str): | |||||
return self.name == other | |||||
else: | |||||
raise NotImplemented | |||||
def __hash__(self): | |||||
return hash(self._name_) | |||||
def __or__(self, other) -> "EventsList": | |||||
return EventsList() | self | other | |||||
class EventEnum(_SingleEventState, Enum): | |||||
pass | |||||
@unique | |||||
class Events(EventEnum): | |||||
on_after_trainer_initialized = "on_after_trainer_initialized" | |||||
on_sanity_check_begin = "on_sanity_check_begin" | |||||
on_sanity_check_end = "on_sanity_check_end" | |||||
on_train_begin = "on_train_begin" | |||||
on_train_end = "on_train_end" | |||||
on_train_epoch_begin = "on_train_epoch_begin" | |||||
on_train_epoch_end = "on_train_epoch_end" | |||||
on_fetch_data_begin = "on_fetch_data_begin" | |||||
on_fetch_data_end = "on_fetch_data_end" | |||||
on_train_batch_begin = "on_train_batch_begin" | |||||
on_train_batch_end = "on_train_batch_end" | |||||
on_exception = "on_exception" | |||||
on_save_model = "on_save_model" | |||||
on_load_model = "on_load_model" | |||||
on_save_checkpoint = "on_save_checkpoint" | |||||
on_load_checkpoint = "on_load_checkpoint" | |||||
on_before_backward = "on_before_backward" | |||||
on_after_backward = "on_after_backward" | |||||
on_before_optimizers_step = "on_before_optimizers_step" | |||||
on_after_optimizers_step = "on_after_optimizers_step" | |||||
on_before_zero_grad = "on_before_zero_grad" | |||||
on_after_zero_grad = "on_after_zero_grad" | |||||
on_evaluate_begin = "on_evaluate_begin" | |||||
on_evaluate_end = "on_evaluate_end" | |||||
class EventsList: | |||||
"""Collection of events stacked by operator `__or__`. | |||||
""" | |||||
def __init__(self) -> None: | |||||
self._events = [] # type: List[Union[Events, _SingleEventState]] | |||||
def _append(self, event: Union[Events, _SingleEventState]) -> None: | |||||
if not isinstance(event, (Events, _SingleEventState)): | |||||
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}") | |||||
self._events.append(event) | |||||
def __getitem__(self, item: int) -> Union[Events, _SingleEventState]: | |||||
return self._events[item] | |||||
def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]: | |||||
return iter(self._events) | |||||
def __len__(self) -> int: | |||||
return len(self._events) | |||||
def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList": | |||||
self._append(event=other) | |||||
return self | |||||
class Filter: | |||||
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): | |||||
r""" | |||||
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; | |||||
:param every: 表示一个函数隔多少次运行一次; | |||||
:param once: 表示一个函数只在第多少次时运行一次; | |||||
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 | |||||
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; | |||||
""" | |||||
if (every is None) and (once is None) and (filter_fn is None): | |||||
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") | |||||
if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): | |||||
raise ValueError("These three values should be only set one.") | |||||
if (filter_fn is not None) and not callable(filter_fn): | |||||
raise TypeError("Argument event_filter should be a callable") | |||||
if (every is not None) and not (isinstance(every, int) and every > 0): | |||||
raise ValueError("Argument every should be integer and greater than zero") | |||||
if (once is not None) and not (isinstance(once, int) and once > 0): | |||||
raise ValueError("Argument once should be integer and positive") | |||||
# 设置变量,包括全局变量; | |||||
self.num_called = 0 | |||||
self.num_executed = 0 | |||||
if every is not None: | |||||
self._every = every | |||||
self._filter = self.every_filter | |||||
elif once is not None: | |||||
self._once = once | |||||
self._filter = self.once_filter | |||||
else: | |||||
self._filter = filter_fn | |||||
def __call__(self, fn: Callable): | |||||
@wraps(fn) | |||||
def wrapper(*args, **kwargs) -> Callable: | |||||
self.num_called += 1 | |||||
# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; | |||||
trainer = args[0] | |||||
if self._filter(self, trainer): | |||||
self.num_executed += 1 | |||||
return fn(*args, **kwargs) | |||||
wrapper.__fastNLP_filter__ = self | |||||
return wrapper | |||||
def every_filter(self, *args): | |||||
return self.num_called % self._every == 0 | |||||
def once_filter(self, *args): | |||||
return self.num_called == self._once | |||||
def state_dict(self) -> Dict: | |||||
r""" | |||||
通过该函数来保存该 `Filter` 的状态; | |||||
""" | |||||
return {"num_called": self.num_called, "num_executed": self.num_executed} | |||||
def load_state_dict(self, state: Dict): | |||||
r""" | |||||
通过该函数来加载 `Filter` 的状态; | |||||
:param state: 通过 `Filter.state_dict` 函数保存的状态元组; | |||||
""" | |||||
self.num_called = state["num_called"] | |||||
self.num_executed = state["num_executed"] | |||||
@@ -6,7 +6,7 @@ __all__ = [
     'CallbackManager'
 ]
-from .callback_events import Events
+from .callback_event import Event
 from .callback import Callback
 from fastNLP.core.log import logger
 from .progress_callback import ProgressCallback, choose_progress_callback
@@ -110,7 +110,7 @@ class CallbackManager:
     def initialize_class_callbacks(self):
         r"""
         At runtime we split a concrete callback instance into its individual callback functions and add them to a
-        dict whose keys are the individual callback hooks, i.e. the categories of `Events`;
+        dict whose keys are the individual callback hooks, i.e. the categories of `Event`;
         if a callback function of a callback class has no effect, we do not actually add it to the dict;
 
         :param callbacks:
@@ -127,11 +127,12 @@ class CallbackManager:
         :param callback: a concrete callback instance;
         """
         self.all_callbacks.append(callback)
-        for name, member in Events.__members__.items():
-            _fn = getattr(callback, member.value)
-            if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)):
-                self.callback_fns[member.value].append(_fn)
-                self.extract_callback_filter_state(callback.callback_name, _fn)
+        for name, member in Event.__dict__.items():
+            if isinstance(member, staticmethod):
+                _fn = getattr(callback, name)
+                if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)):
+                    self.callback_fns[name].append(_fn)
+                    self.extract_callback_filter_state(callback.callback_name, _fn)
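
Since Event is a plain class rather than an Enum, the hook names are now discovered by scanning its staticmethod factories; a quick illustration of that idiom (assuming the Event class above):

    hook_names = [name for name, member in Event.__dict__.items()
                  if isinstance(member, staticmethod)]
    # e.g. ['on_after_trainer_initialized', 'on_sanity_check_begin', ...]
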
 
     def extract_callback_filter_state(self, callback_name, callback_fn):
         r"""
@@ -161,7 +161,6 @@ class MonitorUtility:
         return monitor_name
 
 class HasMonitorCallback(MonitorUtility, Callback):
     def __init__(self, monitor, larger_better, must_have_monitor=False):
         """
@@ -1,4 +1,20 @@
 __all__ = [
-    'Collator'
+    'Collator',
+    'NumpyNumberPadder',
+    'NumpySequencePadder',
+    "NumpyTensorPadder",
+    "Padder",
+    "NullPadder",
+    "RawNumberPadder",
+    "RawSequencePadder",
+    'TorchNumberPadder',
+    'TorchSequencePadder',
+    'TorchTensorPadder',
+    "PaddleNumberPadder",
+    "PaddleTensorPadder",
+    "PaddleSequencePadder",
+    "get_padded_numpy_array",
 ]
 from .collator import Collator
+from .padders import *
@@ -65,12 +65,16 @@ def _get_backend() -> str:
         return catch_backend[0]
 
     # approach (2)
+    for backend in CHECK_BACKEND:
+        if backend in sys.modules:
+            logger.debug(f"sys.modules contains backend:{backend}.")
+            return backend
     for key, module in sys.modules.items():
         catch_backend = _check_module(module)
         if catch_backend:
             break
     if len(catch_backend):
-        logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
+        logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
         return catch_backend[0]
 
     return 'numpy'
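
With backend='auto', the backend is picked from the calling environment as sketched above; an illustration of the observable behaviour (assuming torch is installed):

    import torch  # importing torch puts it into sys.modules
    from fastNLP.core.collators import Collator

    collator = Collator(backend='auto')
    # when used from code that has already imported torch, _get_backend()
    # finds 'torch' in sys.modules, so paddable fields become torch.Tensor;
    # with no known backend imported anywhere, it falls back to 'numpy'
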
@@ -227,7 +231,7 @@ class Collator:
         Set what type of tensor the paddable fields are padded to by default.
 
         :param backend: which kind of tensor to use for the paddable fields; one of ['torch','jittor','paddle','numpy','raw', 'auto', None].
-            If 'auto', the backend is determined from the calling environment when padding is performed.
+            If 'auto', the backend is determined automatically from the calling environment when padding is performed.
         :return:
         """
         assert backend in SUPPORTED_BACKENDS
@@ -0,0 +1,30 @@
+__all__ = [
+    'NumpyNumberPadder',
+    'NumpySequencePadder',
+    "NumpyTensorPadder",
+    "Padder",
+    "NullPadder",
+    "RawNumberPadder",
+    "RawSequencePadder",
+    'TorchNumberPadder',
+    'TorchSequencePadder',
+    'TorchTensorPadder',
+    "PaddleNumberPadder",
+    "PaddleTensorPadder",
+    "PaddleSequencePadder",
+    "get_padded_numpy_array",
+]
+
+from .numpy_padder import *
+from .padder import Padder, NullPadder
+from .raw_padder import *
+from .torch_padder import *
+from .paddle_padder import *
+from .utils import get_padded_numpy_array
@@ -1,8 +1,3 @@
-from typing import Dict
 from typing import Sequence, Any, Union, Dict
 from abc import ABC
@@ -12,7 +7,7 @@ from fastNLP.core.log import logger
 from .padder import Padder, NullPadder
 from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
 from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
-from .raw_padder import RawNumberPadder, RawSequencePadder
+from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
 from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
 from .exceptions import *
@@ -28,7 +23,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
     :param field_name: used only for clearer error messages.
     :return:
     """
-    assert len(batch_field)!=0, "Empty batch encountered."
     logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
     if pad_val is None:
         logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
@@ -68,7 +63,10 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
         return NullPadder()
 
     # then check whether the types of all elements are consistent
-    ele_dtypes = set([v[1] for v in catalog.values()])
+    try:
+        ele_dtypes = set([v[1] for v in catalog.values()])
+    except TypeError:  # unhashable types fall back to their string representation
+        ele_dtypes = set([str(v[1]) for v in catalog.values()])
     num_eletypes = len(ele_dtypes)
     if num_eletypes != 1:
         msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
@@ -80,7 +78,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
     depth = depths.pop()
     shape_len = shape_lens.pop()
-    ele_dtype = ele_dtypes.pop()
+    ele_dtype = list(catalog.values())[0][1]  # handled this way because of the except branch above
 
     # the padder itself decides whether padding is possible.
     try:
@@ -93,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
             return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'paddle':
             return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        else:
+            raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
 
     if depth > 1 and shape_len == 0:  # something like [[0, 1], [2]]
         if backend == 'raw':
@@ -103,14 +103,21 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
             return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'paddle':
             return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        else:
+            raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
 
-    if depth == 1 and shape_len != 0:
-        if backend == 'numpy':
-            return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+    # if the elements have a shape, this only works when the objects provide a tolist() method
+    if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)):
+        if backend == 'raw':
+            return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
+        elif backend == 'numpy':
+            return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
         elif backend == 'torch':
-            return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+            return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
         elif backend == 'paddle':
-            return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+            return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
+        else:
+            raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
 
     if shape_len != 0 and depth>1:
         msg = "Does not support pad tensor under nested list. If you need this, please report."
@@ -179,23 +186,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
     else:  # includes int/float/bool/dict and other things that cannot be padded
         catalog[parent] = ((), type(content))  # () means the shape has length 0; the second item is its type
     return catalog
-
-"""
-from numbers import Number
-issubclass(type(3), Number)  # True
-issubclass(type(3.1), Number)  # True
-issubclass(type('3'), Number)  # False
-issubclass(type(True), Number)  # True
-issubclass(type(np.zeros(3)[0]), Number)  # True
-isinstance(np.zeros(3, dtype=float).dtype, np.dtype)  # True
-isinstance(np.zeros(3, dtype=int).dtype, np.dtype)  # True
-isinstance(np.zeros(3, dtype=str).dtype, np.dtype)  # True, has to be told apart by other means
-is_torch_tensor_dtype()  # can be checked via isinstance(torch.zeros(3).dtype, torch.dtype)
-"""
@@ -66,7 +66,7 @@ class NumpySequencePadder(Padder):
 class NumpyTensorPadder(Padder):
     def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
         """
-        Pads a field like [np.array([3, 4]), np.array([1])].
+        Pads a field like [np.array([3, 4]), np.array([1])]. If the elements are not np.ndarray, they must provide a tolist() method.
 
         :param pad_val: the value used for padding.
         :param ele_dtype: used to check whether the element type of this field can be converted to np.array.
@@ -77,6 +77,13 @@ class NumpyTensorPadder(Padder):
 
     @staticmethod
     def pad(batch_field, pad_val, dtype):
+        try:
+            if not isinstance(batch_field[0], np.ndarray):
+                batch_field = [np.array(field.tolist()) for field in batch_field]
+        except AttributeError:
+            raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), "
+                               f"it must have tolist() method.")
+
         shapes = [field.shape for field in batch_field]
         max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
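
A quick sketch of the expected behaviour (illustrative; the fill step happens after the hunk shown above):

    import numpy as np

    batch = [np.array([3, 4]), np.array([1])]
    out = NumpyTensorPadder.pad(batch, pad_val=0, dtype=int)
    # expected: out == np.array([[3, 4], [1, 0]]), shape (2, 2)
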
@@ -56,7 +56,7 @@ def is_paddle_dtype_str(dtype):
 
 def _get_dtype(ele_dtype, dtype, class_name):
-    if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
+    if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
         raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                        f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")
@@ -74,13 +74,20 @@ def _get_dtype(ele_dtype, dtype, class_name):
     elif is_numpy_generic_class(ele_dtype):
         dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
     else:
-        dtype == ele_dtype
+        dtype = ele_dtype
 
     return dtype
 
 class PaddleNumberPadder(Padder):
-    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+        """
+        Converts data like [1, 2, 3] into paddle.Tensor([1, 2, 3]).
+
+        :param pad_val: this value has no effect here.
+        :param ele_dtype: used to check whether the element type of this field can be converted to paddle.tensor.
+        :param dtype: the dtype of the output data, e.g. int, float, 'int32'.
+        """
         # only when ele_dtype is a python number / numpy number or a tensor
         dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
         super().__init__(pad_val=pad_val, dtype=dtype)
@@ -91,7 +98,14 @@ class PaddleNumberPadder(Padder):
 
 class PaddleSequencePadder(Padder):
-    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+    def __init__(self, ele_dtype=None, pad_val=0, dtype=None):
+        """
+        Pads content like [[1], [1, 2]] into paddle.Tensor([[1, 0], [1, 2]]); can pad multiply nested data.
+
+        :param pad_val: the value used for padding.
+        :param ele_dtype: used to check whether the element type of this field can be converted to paddle.tensor.
+        :param dtype: the dtype of the output data, e.g. int, float, 'int32'.
+        """
         dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
         super().__init__(pad_val=pad_val, dtype=dtype)
@@ -102,19 +116,26 @@ class PaddleSequencePadder(Padder):
 
 class PaddleTensorPadder(Padder):
-    def __init__(self, ele_dtype, pad_val=0, dtype=None):
+    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
         """
-        Currently only supports something like [paddle.tensor([3, 2]), paddle.tensor([1])].
+        Supports something like [paddle.tensor([3, 2]), paddle.tensor([2, 1])]; if the elements are not paddle.tensor, they must provide a tolist() method.
 
-        :param ele_dtype:
-        :param pad_val:
-        :param dtype:
+        :param pad_val: the value used for padding.
+        :param ele_dtype: used to check whether the element type of this field can be converted to paddle.tensor.
+        :param dtype: the dtype of the output data, e.g. int, float, 'int32'.
         """
         dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
         super().__init__(pad_val=pad_val, dtype=dtype)
 
     @staticmethod
     def pad(batch_field, pad_val, dtype):
+        try:
+            if not isinstance(batch_field[0], paddle.Tensor):
+                batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field]
+        except AttributeError:
+            raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
+                               f"it must have tolist() method.")
+
         shapes = [field.shape for field in batch_field]
         max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         if isinstance(dtype, np.dtype):
@@ -174,6 +195,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
     """
     shapes = get_shape(batch_field)
     tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
-    # tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val)
     tensor = fill_tensor(batch_field, tensor, dtype=dtype)
     return tensor
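
A sketch of the module-level helper on a ragged nested batch (assuming paddle is installed; get_shape and fill_tensor are helpers defined elsewhere in this file):

    batch = [[1], [1, 2]]
    tensor = get_padded_paddle_tensor(batch, dtype='int64', pad_val=0)
    # expected: a paddle.Tensor equal to [[1, 0], [1, 2]]
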
@@ -1,4 +1,8 @@
+__all__ = [
+    "RawNumberPadder",
+    "RawSequencePadder",
+    "RawTensorPadder"
+]
 
 from .padder import Padder
 from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number
@@ -63,3 +67,34 @@ class RawSequencePadder(Padder):
         :return:
         """
         return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()
+
+
+class RawTensorPadder(Padder):
+    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+        """
+        Pads content like [[1], [1, 2]] into [[1, 0], [1, 2]]; can pad multiply nested data.
+
+        :param pad_val: the value used for padding.
+        :param ele_dtype: used to check whether the element type of this field can be converted to np.array.
+        :param dtype: the dtype of the output data.
+        """
+        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        """
+        :param batch_field:
+        :param pad_val:
+        :param dtype: this argument has no effect here.
+        :return:
+        """
+        try:
+            if not isinstance(batch_field[0], (list, tuple)):
+                batch_field = [field.tolist() for field in batch_field]
+        except AttributeError:
+            raise RuntimeError(f"If the field is not a list or tuple(it is {type(batch_field[0])}), "
+                               f"it must have tolist() method.")
+
+        return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()
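
A sketch of RawTensorPadder on a batch of numpy arrays (illustrative):

    import numpy as np

    batch = [np.array([3, 2]), np.array([1])]
    RawTensorPadder.pad(batch, pad_val=0, dtype=None)
    # expected: [[3, 2], [1, 0]] as plain nested lists, no tensor backend involved
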
@@ -1,4 +1,8 @@
+__all__ = [
+    'TorchNumberPadder',
+    'TorchSequencePadder',
+    'TorchTensorPadder'
+]
 
 from inspect import isclass
 import numpy as np
@@ -37,7 +41,7 @@ def is_torch_tensor(dtype):
 
 def _get_dtype(ele_dtype, dtype, class_name):
-    if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
+    if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
         raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                        f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")
@@ -97,7 +101,7 @@ class TorchSequencePadder(Padder):
 class TorchTensorPadder(Padder):
     def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
         """
-        Currently only supports something like [torch.tensor([3, 2]), torch.tensor([1])].
+        Supports something like [torch.tensor([3, 2]), torch.tensor([1])]; if the elements are not torch.tensor, they must provide a tolist() method.
 
         :param pad_val: the value used for padding.
         :param ele_dtype: used to check whether the element type of this field can be converted to torch.tensor.
@@ -108,6 +112,13 @@ class TorchTensorPadder(Padder):
 
     @staticmethod
     def pad(batch_field, pad_val, dtype):
+        try:
+            if not isinstance(batch_field[0], torch.Tensor):
+                batch_field = [torch.tensor(field.tolist()) for field in batch_field]
+        except AttributeError:
+            raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
+                               f"it must have tolist() method.")
+
         shapes = [field.shape for field in batch_field]
         max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
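
A quick sketch (assuming torch is installed; the fill step happens after the hunk shown above):

    import torch

    batch = [torch.tensor([3, 2]), torch.tensor([1])]
    out = TorchTensorPadder.pad(batch, pad_val=0, dtype=torch.long)
    # expected: out.shape == (2, 2) and out[1].tolist() == [1, 0]
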
@@ -1,6 +1,10 @@
+__all__ = [
+    'get_padded_numpy_array'
+]
+
 from typing import Sequence, List
+from numbers import Number
 import re
 from inspect import isclass
@@ -2,8 +2,6 @@ __all__ = [
     'Loop',
     'EvaluateBatchLoop',
     'TrainBatchLoop',
-    'State',
-    'TrainerState',
     'Evaluator',
     'Trainer',
 ]
@@ -17,10 +17,10 @@ from .utils import State, TrainerState
 from .utils.utils import check_evaluate_every
 from .evaluator import Evaluator
 from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
-from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
+from fastNLP.core.callbacks import Callback, CallbackManager
 from fastNLP.core.callbacks.callback import _CallbackWrapper
 from fastNLP.core.callbacks.callback_manager import prepare_callbacks
-from fastNLP.core.callbacks.callback_events import _SingleEventState
+from fastNLP.core.callbacks.callback_event import Event
 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger):
             raise e
         finally:
             self.on_train_end()
-            self.driver.barrier()
 
     def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
         def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
@@ -399,7 +398,7 @@ class Trainer(TrainerEventTrigger):
                 if self.cur_epoch_idx % evaluate_every == 0:
                     self.run_evaluate()
 
-    def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
+    def add_callback_fn(self, event: Event, fn: Callable):
         r"""
         After a trainer instance has been initialized, the user can use this function to conveniently add callback
         functions; this is handled by the concrete trainer instance itself, so no `mark` parameter is needed;
@@ -407,19 +406,69 @@ class Trainer(TrainerEventTrigger): | |||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | ||||
:param fn: 具体的 callback 函数; | :param fn: 具体的 callback 函数; | ||||
""" | """ | ||||
if not isinstance(event, (_SingleEventState, EventsList)): | |||||
raise ValueError("parameter event should only be `Events` or `EventsList` type.") | |||||
if not isinstance(event, Event): | |||||
raise ValueError("parameter event should only be `Event` type.") | |||||
_custom_callback = _CallbackWrapper(event, fn) | _custom_callback = _CallbackWrapper(event, fn) | ||||
self.callback_manager.dissect_one_callback(_custom_callback) | self.callback_manager.dissect_one_callback(_custom_callback) | ||||
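For example, a run-time hook on an already constructed trainer (a sketch; it assumes Event exposes one factory method per event name, as the examples further below do for on_save_model)::

from fastNLP import Event

def log_train_end(trainer):
    print("training finished at epoch", trainer.cur_epoch_idx)

trainer.add_callback_fn(event=Event.on_train_end(), fn=log_train_end)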
@classmethod | @classmethod | ||||
def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None): | |||||
def on(cls, event: Event, marker: Optional[str] = None): | |||||
r""" | r""" | ||||
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | ||||
The supported events are listed below, in the order in which they fire. The argument list that a function decorated for each event should accept is also shown, for example:
Trainer.__init__(): | |||||
on_after_trainer_initialized(trainer, driver) | |||||
Trainer.run(): | |||||
if num_eval_sanity_batch>0: | |||||
on_sanity_check_begin(trainer)  # only if num_eval_sanity_batch is set
on_sanity_check_end(trainer, sanity_check_res) | |||||
try: | |||||
on_train_begin(trainer) | |||||
while cur_epoch_idx < n_epochs: | |||||
on_train_epoch_begin(trainer) | |||||
while batch_idx_in_epoch<=num_batches_per_epoch: | |||||
on_fetch_data_begin(trainer) | |||||
batch = next(dataloader) | |||||
on_fetch_data_end(trainer) | |||||
on_train_batch_begin(trainer, batch, indices) | |||||
on_before_backward(trainer, outputs)  # outputs are the model outputs, after output_mapping has been applied if one is set.
on_after_backward(trainer) | |||||
on_before_zero_grad(trainer, optimizers)  # actual invocation depends on accumulation_steps
on_after_zero_grad(trainer, optimizers)  # actual invocation depends on accumulation_steps
on_before_optimizers_step(trainer, optimizers)  # actual invocation depends on accumulation_steps
on_after_optimizers_step(trainer, optimizers)  # actual invocation depends on accumulation_steps
on_train_batch_end(trainer) | |||||
on_train_epoch_end(trainer) | |||||
except BaseException: | |||||
on_exception(trainer, exception)
finally: | |||||
on_train_end(trainer) | |||||
Other callbacks, such as on_evaluate_begin(trainer) / on_evaluate_end(trainer, results) / on_save_model(trainer) /
on_load_model(trainer) / on_save_checkpoint(trainer) / on_load_checkpoint(trainer), are invoked at the
appropriate points inside Trainer.run() as needed.
Example:: | |||||
from fastNLP import Event | |||||
@Trainer.on(Event.on_save_model()) | |||||
def do_something_1(trainer): | |||||
# do something | |||||
# the function above runs whenever the Trainer saves the model.
@Trainer.on(Event.on_save_model(once=True)) | |||||
def do_something_2(trainer): | |||||
# do something | |||||
# the function above runs when the Trainer saves the model, but only once.
@Trainer.on(Event.on_train_batch_begin(every=2)) | |||||
def do_something_3(trainer, batch, indices): | |||||
# do something | |||||
# the function above runs at the start of each new batch, but only once every two batches.
注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; | 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; | ||||
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; | |||||
:param event: the specific callback moment; the user must specify which moment this callback function belongs
to. The arguments the function must accept for each moment are documented above.
:param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, | :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, | ||||
表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer | 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer | ||||
实例使用; | 实例使用; | ||||
@@ -427,9 +476,9 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
def wrapper(fn: Callable) -> Callable: | def wrapper(fn: Callable) -> Callable: | ||||
cls._custom_callbacks[marker].append((event, fn)) | |||||
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] | ||||
_check_valid_parameters_number(fn, callback_fn_args) | _check_valid_parameters_number(fn, callback_fn_args) | ||||
cls._custom_callbacks[marker].append((event, fn)) | |||||
return fn | return fn | ||||
return wrapper | return wrapper | ||||
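A sketch of the marker semantics described above (the callback body is illustrative)::

from fastNLP import Trainer, Event

@Trainer.on(Event.on_train_batch_begin(every=10), marker='exp1')
def log_batch(trainer, batch, indices):
    print(trainer.cur_epoch_idx, len(indices))

# only a Trainer later constructed with marker='exp1' picks this callback up;
# marker='all' would have attached it to every trainer instance.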
@@ -441,6 +490,7 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"]) | _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"]) | ||||
_own_callbacks.extend(self._custom_callbacks[None]) | _own_callbacks.extend(self._custom_callbacks[None]) | ||||
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().") | |||||
self._custom_callbacks[None] = [] | self._custom_callbacks[None] = [] | ||||
if self.marker is not None: | if self.marker is not None: | ||||
if len(self._custom_callbacks[self.marker]) == 0: | if len(self._custom_callbacks[self.marker]) == 0: | ||||
@@ -14,7 +14,7 @@ else: | |||||
from fastNLP.core.dataset import DataSet as Dataset | from fastNLP.core.dataset import DataSet as Dataset | ||||
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps | from fastNLP.core.utils.jittor_utils import jittor_collate_wraps | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataset import DataSet as FDataSet | from fastNLP.core.dataset import DataSet as FDataSet | ||||
@@ -107,33 +107,33 @@ class JittorDataLoader: | |||||
return len(self.dataset) // self.dataset.batch_size | return len(self.dataset) // self.dataset.batch_size | ||||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | ||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> "JittorDataLoader": | |||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
pad_fn:Callable=None) -> Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
If you need special handling for the content of a particular field, use this function.
:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the
corresponding key directly; for a nested dict, use a tuple of keys, e.g. ('a', 'b') for {'a': {'b': 1}}.
If __getitem__ returns a Sequence, '_0', '_1', ... refer to the 0th, 1st, ... element. An error is raised
if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: default pad value for this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so a field that is not paddable does not need to be set to None
explicitly. Has no effect when backend is None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var output respectively. Has no effect when pad_val is None.
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
The Collator automatically un-batches the data and groups each field into its own batch; pad_fn receives
the batch form of the field, and its output is used directly as the result.
:return: the Collator being used
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | if isinstance(self._collate_fn, Collator): | ||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, | |||||
backend=backend) | |||||
return self | |||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return self._collate_fn | |||||
else: | else: | ||||
raise ValueError(f"collate_fn is not fastnlp collator") | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | |||||
def set_ignore(self, *field_names) -> "JittorDataLoader": | |||||
def set_ignore(self, *field_names) -> Collator: | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Ex:: | Ex:: | ||||
@@ -146,18 +146,17 @@ class JittorDataLoader: | |||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | if isinstance(self._collate_fn, Collator): | ||||
self._collate_fn.set_ignore(*field_names) | self._collate_fn.set_ignore(*field_names) | ||||
return self | |||||
return self._collate_fn | |||||
else: | else: | ||||
raise ValueError(f"collate_fn is not fastnlp collator") | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | |||||
def get_batch_indices(self) -> List[int]: | def get_batch_indices(self) -> List[int]: | ||||
""" | """ | ||||
获取当前数据的idx | |||||
Get the indices of the current batch.
:return: | :return: | ||||
""" | """ | ||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_jittor_dataloader(): | def prepare_jittor_dataloader(): | ||||
... | ... |
@@ -15,8 +15,9 @@ else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | ||||
from fastNLP.core.collators.collator import Collator | from fastNLP.core.collators.collator import Collator | ||||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataset import DataSet as FDataSet | from fastNLP.core.dataset import DataSet as FDataSet | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | |||||
class _PaddleDataset(Dataset): | class _PaddleDataset(Dataset): | ||||
@@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader): | |||||
if not isinstance(dataset, _PaddleDataset): | if not isinstance(dataset, _PaddleDataset): | ||||
dataset = _PaddleDataset(dataset) | dataset = _PaddleDataset(dataset) | ||||
if batch_sampler is None: | |||||
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, | |||||
drop_last=drop_last) | |||||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | ||||
return_list=return_list, batch_sampler=batch_sampler, | return_list=return_list, batch_sampler=batch_sampler, | ||||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | ||||
@@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader): | |||||
if isinstance(dataset.dataset, FDataSet): | if isinstance(dataset.dataset, FDataSet): | ||||
self._collate_fn = dataset.dataset.collator | self._collate_fn = dataset.dataset.collator | ||||
self._collate_fn.set_backend(backend="paddle") | self._collate_fn.set_backend(backend="paddle") | ||||
# if collate_fn is not None: | |||||
# self._collate_fn.add_collator(collate_fn) | |||||
else: | else: | ||||
self._collate_fn = Collator(backend="paddle") | self._collate_fn = Collator(backend="paddle") | ||||
@@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader): | |||||
self.cur_batch_indices = indices | self.cur_batch_indices = indices | ||||
yield data | yield data | ||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> "PaddleDataLoader": | |||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
pad_fn:Callable=None) -> Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
If you need special handling for the content of a particular field, use this function.
:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the
corresponding key directly; for a nested dict, use a tuple of keys, e.g. ('a', 'b') for {'a': {'b': 1}}.
If __getitem__ returns a Sequence, '_0', '_1', ... refer to the 0th, 1st, ... element. An error is raised
if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: default pad value for this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so a field that is not paddable does not need to be set to None
explicitly. Has no effect when backend is None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var output respectively. Has no effect when pad_val is None.
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
The Collator automatically un-batches the data and groups each field into its own batch; pad_fn receives
the batch form of the field, and its output is used directly as the result.
:return: the Collator being used
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | if isinstance(self._collate_fn, Collator): | ||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, | |||||
backend=backend) | |||||
return self | |||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return self._collate_fn | |||||
else: | else: | ||||
raise ValueError(f"collate_fn is not fastnlp collator") | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | |||||
def set_ignore(self, *field_names) -> "PaddleDataLoader": | |||||
def set_ignore(self, *field_names) -> Collator: | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Ex:: | Ex:: | ||||
@@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader): | |||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | if isinstance(self._collate_fn, Collator): | ||||
self._collate_fn.set_ignore(*field_names) | self._collate_fn.set_ignore(*field_names) | ||||
return self | |||||
return self._collate_fn | |||||
else: | else: | ||||
raise ValueError(f"collate_fn is not fastnlp collator") | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | |||||
def get_batch_indices(self) -> List[int]: | def get_batch_indices(self) -> List[int]: | ||||
""" | """ | ||||
获取当前数据的idx | |||||
Get the indices of the current batch.
:return: | :return: | ||||
""" | """ | ||||
@@ -147,7 +150,8 @@ class PaddleDataLoader(DataLoader): | |||||
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | ||||
return_list: bool = True, batch_sampler=None, | |||||
return_list: bool = True, | |||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||||
train_batch_size: int = 1, shuffle: bool = False, | train_batch_size: int = 1, shuffle: bool = False, | ||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, | drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, | ||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
@@ -3,14 +3,14 @@ __all__ = [ | |||||
'prepare_torch_dataloader' | 'prepare_torch_dataloader' | ||||
] | ] | ||||
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping | |||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | |||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import DataLoader, Sampler | from torch.utils.data import DataLoader, Sampler | ||||
@@ -76,6 +76,10 @@ class TorchDataLoader(DataLoader): | |||||
if not isinstance(dataset, _FDataSet): | if not isinstance(dataset, _FDataSet): | ||||
dataset = _FDataSet(dataset) | dataset = _FDataSet(dataset) | ||||
if sampler is None and batch_sampler is None: | |||||
sampler = RandomSampler(dataset, shuffle=shuffle) | |||||
shuffle = False  # shuffling is now owned by the RandomSampler; avoid passing both a sampler and shuffle=True to DataLoader
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | ||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | ||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -87,9 +91,6 @@ class TorchDataLoader(DataLoader): | |||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | ||||
self._collate_fn = dataset.dataset.collator | self._collate_fn = dataset.dataset.collator | ||||
self._collate_fn.set_backend(backend="torch") | self._collate_fn.set_backend(backend="torch") | ||||
# if collate_fn is not None and collate_fn is not default_collate: | |||||
# # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 | |||||
# self._collate_fn.add_collator(collate_fn) | |||||
else: | else: | ||||
self._collate_fn = Collator(backend="torch") | self._collate_fn = Collator(backend="torch") | ||||
else: | else: | ||||
@@ -112,31 +113,32 @@ class TorchDataLoader(DataLoader): | |||||
yield data | yield data | ||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | ||||
pad_fn:Callable=None) -> "TorchDataLoader": | |||||
pad_fn:Callable=None) -> Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||||
If you need special handling for the content of a particular field, use this function.
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||||
形式,输出将被直接作为结果输出。 | |||||
:return: 返回 Collator 自身 | |||||
:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the
corresponding key directly; for a nested dict, use a tuple of keys, e.g. ('a', 'b') for {'a': {'b': 1}}.
If __getitem__ returns a Sequence, '_0', '_1', ... refer to the 0th, 1st, ... element. An error is raised
if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: default pad value for this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so a field that is not paddable does not need to be set to None
explicitly. Has no effect when backend is None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var output respectively. Has no effect when pad_val is None.
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
The Collator automatically un-batches the data and groups each field into its own batch; pad_fn receives
the batch form of the field, and its output is used directly as the result.
:return: the Collator being used
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | if isinstance(self._collate_fn, Collator): | ||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | ||||
return self | |||||
return self._collate_fn | |||||
else: | else: | ||||
raise ValueError(f"collate_fn is not fastnlp collator") | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | |||||
def set_ignore(self, *field_names) -> "TorchDataLoader": | |||||
def set_ignore(self, *field_names) -> Collator: | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Ex:: | Ex:: | ||||
@@ -149,24 +151,23 @@ class TorchDataLoader(DataLoader): | |||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | if isinstance(self._collate_fn, Collator): | ||||
self._collate_fn.set_ignore(*field_names) | self._collate_fn.set_ignore(*field_names) | ||||
return self | |||||
return self._collate_fn | |||||
else: | else: | ||||
raise ValueError(f"collate_fn is not fastnlp collator") | |||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | |||||
def get_batch_indices(self) -> List[int]: | def get_batch_indices(self) -> List[int]: | ||||
""" | """ | ||||
获取当前数据的idx | |||||
Get the indices of the current batch.
:return: | :return: | ||||
""" | """ | ||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
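Putting the new return values together, a runnable sketch (assuming prepare_torch_dataloader is importable from this module and torch is installed; field names are illustrative)::

from fastNLP import DataSet
from fastNLP.core.dataloaders.torch_dataloader import prepare_torch_dataloader

ds = DataSet({'x': [[1, 2], [3], [4, 5, 6]], 'y': [0, 1, 0]})
dl = prepare_torch_dataloader(ds, batch_size=2)
collator = dl.set_pad('x', pad_val=-1)  # now returns the Collator, not the dataloader
collator.set_ignore('y')                # further tweaks chain on the collator
for batch in dl:
    print(batch['x'])                   # 'x' padded with -1; 'y' is absent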
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | ||||
batch_size: int = 1, | batch_size: int = 1, | ||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||||
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | |||||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, | num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, | ||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
@@ -0,0 +1,16 @@ | |||||
def indice_collate_wrapper(func): | |||||
""" | |||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||||
:param func: 需要修饰的函数 | |||||
:return: | |||||
""" | |||||
def wrapper(tuple_data): | |||||
indice, ins_list = [], [] | |||||
for idx, ins in tuple_data: | |||||
indice.append(idx) | |||||
ins_list.append(ins) | |||||
return indice, func(ins_list) | |||||
return wrapper |
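A quick demonstration of the wrapper's behavior (the identity collate function is illustrative)::

def collate(ins_list):
    return ins_list

wrapped = indice_collate_wrapper(collate)
indices, batch = wrapped([(0, {'x': 1}), (3, {'x': 2})])
print(indices)  # [0, 3]
print(batch)    # [{'x': 1}, {'x': 2}]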
@@ -770,17 +770,8 @@ class DataSet: | |||||
df = self.to_pandas() | df = self.to_pandas() | ||||
return df.to_csv(path, encoding="utf-8") | return df.to_csv(path, encoding="utf-8") | ||||
def set_ignore(self, *field_names) -> None: | |||||
""" | |||||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||||
:param field_names: | |||||
:return: | |||||
""" | |||||
self.collator.set_ignore(*field_names) | |||||
@property | @property | ||||
def collator(self): | |||||
def collator(self) -> Collator: | |||||
if self._collator is None: | if self._collator is None: | ||||
self._collator = Collator() | self._collator = Collator() | ||||
return self._collator | return self._collator |
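With set_ignore removed from DataSet, the same effect now goes through the collator property (a sketch; the field names are illustrative)::

from fastNLP import DataSet

ds = DataSet({'x': [[1, 2], [3]], 'y': [0, 1]})
ds.collator.set_ignore('y')  # 'y' will be dropped from batches built over this dataset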
@@ -22,7 +22,7 @@ from fastNLP.core.utils import ( | |||||
rank_zero_rm | rank_zero_rm | ||||
) | ) | ||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
RandomBatchSampler, | |||||
ReproduceBatchSampler, | |||||
ReproducibleSampler, | ReproducibleSampler, | ||||
ReproducibleBatchSampler, | ReproducibleBatchSampler, | ||||
RandomSampler, | RandomSampler, | ||||
@@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver): | |||||
return self.model, model.forward | return self.model, model.forward | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]], | |||||
reproducible: bool = False): | reproducible: bool = False): | ||||
r""" | r""" | ||||
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 | ||||
@@ -22,7 +22,7 @@ from fastNLP.core.log import logger | |||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
ReproducibleBatchSampler, | ReproducibleBatchSampler, | ||||
ReproducibleSampler, | ReproducibleSampler, | ||||
RandomBatchSampler, | |||||
ReproduceBatchSampler, | |||||
RandomSampler, | RandomSampler, | ||||
) | ) | ||||
@@ -345,7 +345,7 @@ class PaddleDriver(Driver): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = RandomBatchSampler( | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -476,7 +476,7 @@ class PaddleDriver(Driver): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
# RandomBatchSampler 的情况 | |||||
# ReproduceBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | elif hasattr(dataloader.batch_sampler, "batch_sampler"): | ||||
batch_sampler = dataloader.batch_sampler.batch_sampler | batch_sampler = dataloader.batch_sampler.batch_sampler | ||||
res.sampler = batch_sampler.sampler | res.sampler = batch_sampler.sampler | ||||
@@ -14,7 +14,7 @@ from fastNLP.core.utils import ( | |||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
ReproducibleBatchSampler, | ReproducibleBatchSampler, | ||||
RandomBatchSampler, | |||||
ReproduceBatchSampler, | |||||
ReproducibleSampler, | ReproducibleSampler, | ||||
RandomSampler, | RandomSampler, | ||||
re_instantiate_sampler, | re_instantiate_sampler, | ||||
@@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -15,7 +15,7 @@ from .torch_driver import TorchDriver | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.utils.utils import _get_fun_msg | from fastNLP.core.utils.utils import _get_fun_msg | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler | |||||
from fastNLP.core.samplers import RandomSampler | from fastNLP.core.samplers import RandomSampler | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver): | |||||
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
else: | else: | ||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler = ReproduceBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -293,7 +293,7 @@ class TorchDriver(Driver): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " | ||||
"`ReproducibleSampler`.") | "`ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = RandomBatchSampler( | |||||
sampler = ReproduceBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -407,7 +407,7 @@ class TorchDriver(Driver): | |||||
res.shuffle = True | res.shuffle = True | ||||
else: | else: | ||||
res.shuffle = False | res.shuffle = False | ||||
# RandomBatchSampler 的情况 | |||||
# ReproduceBatchSampler 的情况 | |||||
elif hasattr(dataloader.batch_sampler, "batch_sampler"): | elif hasattr(dataloader.batch_sampler, "batch_sampler"): | ||||
batch_sampler = dataloader.batch_sampler.batch_sampler | batch_sampler = dataloader.batch_sampler.batch_sampler | ||||
res.sampler = batch_sampler.sampler | res.sampler = batch_sampler.sampler | ||||
@@ -0,0 +1,25 @@ | |||||
__all__ = [ | |||||
'print' | |||||
] | |||||
from .logger import logger | |||||
def print(*args, sep=' ', end='\n', file=None, flush=False): | |||||
""" | |||||
Redirects the built-in print function to logger.info.
Example::
from fastNLP import print
print("This is a test") # equivalent to calling logger.info("This is a test")
:param args: the content to print
:param sep: separator inserted between multiple arguments.
:param end: has no effect here, because a newline is always appended.
:param file: has no effect.
:param flush: has no effect.
:return:
""" | |||||
line = sep.join(map(str, args))  # mirror the builtin print: arguments need not be str
logger.info(line) |
@@ -14,9 +14,10 @@ __all__ = [ | |||||
"UnrepeatedSortedSampler", | "UnrepeatedSortedSampler", | ||||
"UnrepeatedSequentialSampler", | "UnrepeatedSequentialSampler", | ||||
"RandomBatchSampler", | |||||
"ReproduceBatchSampler", | |||||
"BucketedBatchSampler", | "BucketedBatchSampler", | ||||
"ReproducibleBatchSampler", | "ReproducibleBatchSampler", | ||||
"RandomBatchSampler", | |||||
"re_instantiate_sampler" | "re_instantiate_sampler" | ||||
] | ] | ||||
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling | |||||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | ||||
from .utils import re_instantiate_sampler | from .utils import re_instantiate_sampler | ||||
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler | ||||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||||
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler | |||||
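To summarize the rename: the resumable wrapper formerly called RandomBatchSampler is now ReproduceBatchSampler, while RandomBatchSampler is the new standalone sampler defined further below (a sketch, assuming torch is available)::

from torch.utils.data import BatchSampler, SequentialSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomBatchSampler

data = list(range(10))
inner = BatchSampler(SequentialSampler(data), batch_size=4, drop_last=False)
resumable = ReproduceBatchSampler(inner, batch_size=4, drop_last=False)  # wraps an existing batch_sampler
standalone = RandomBatchSampler(data, batch_size=4, shuffle=True)        # batches the dataset itself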
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'BucketedBatchSampler', | 'BucketedBatchSampler', | ||||
"ReproduceBatchSampler", | |||||
"RandomBatchSampler" | "RandomBatchSampler" | ||||
] | ] | ||||
@@ -7,7 +8,6 @@ import math | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from typing import Dict, Union, List | from typing import Dict, Union, List | ||||
from itertools import chain | from itertools import chain | ||||
import os | |||||
import numpy as np | import numpy as np | ||||
@@ -54,13 +54,12 @@ class ReproducibleBatchSampler: | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | ||||
class RandomBatchSampler(ReproducibleBatchSampler): | |||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||||
class ReproduceBatchSampler(ReproducibleBatchSampler): | |||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | """ | ||||
可以使得 batch_sampler 对象状态恢复的 wrapper 。 | 可以使得 batch_sampler 对象状态恢复的 wrapper 。 | ||||
:param batch_sampler: an iterable yielding ints or lists of ints. RandomBatchSampler first iterates over it once, caching the
:param batch_sampler: an iterable yielding ints or lists of ints. ReproduceBatchSampler first iterates over it once, caching the
emitted indices, and then re-emits them in batches of batch_size. | emitted indices, and then re-emits them in batches of batch_size.
:param batch_size: 每个 batch 的大小是多少。 | :param batch_size: 每个 batch 的大小是多少。 | ||||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | ||||
@@ -143,7 +142,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
self.need_reinitialize = False | self.need_reinitialize = False | ||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.") | |||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | |||||
def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): | ||||
@@ -158,6 +157,211 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||||
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size | ||||
class RandomBatchSampler(ReproducibleBatchSampler): | |||||
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, | |||||
drop_last: bool = False, seed: int = 0, **kwargs): | |||||
""" | |||||
A batch_sampler that groups samples into batches at random.
:param dataset: a data container implementing __len__.
:param batch_size: size of each batch
:param shuffle: whether to shuffle the sample order on every iteration.
:param drop_last: whether to drop the last batch when it cannot be filled up to batch_size samples.
:param seed: the random seed to use
:param kwargs: reserved for internal use by fastNLP
""" | |||||
super().__init__() | |||||
self.dataset = dataset | |||||
self.batch_size = batch_size | |||||
self.shuffle = shuffle | |||||
self.drop_last = drop_last | |||||
self.seed = seed | |||||
self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # how many samples have been consumed in total, including those emitted on other cards in the multi-card case
# distributed-related parameters
self.num_replicas = kwargs.get("num_replicas", 1)
self.rank = kwargs.get("rank", 0)
self.epoch = kwargs.get("epoch", -1)
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card;
# whether we are inside an iteration; while True, set_distributed() and load_state_dict() must not be called
self.during_iter = kwargs.get("during_iter", False)
# the variables below are internal state used when restoring progress.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
self.old_num_replicas = kwargs.get('old_num_replicas', self.num_replicas)  # defaulting this here keeps a resumed __iter__ from hitting an undefined attribute
def set_distributed(self, num_replicas, rank, pad=True): | |||||
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | |||||
"during an unfinished iteration." | |||||
assert num_replicas > 0 and isinstance(num_replicas, int) | |||||
assert isinstance(rank, int) and 0 <= rank < num_replicas | |||||
# note: when this is called, all state should be that of an epoch that is just about to start;
self.num_replicas = num_replicas | |||||
self.rank = rank | |||||
self.pad = pad | |||||
return self | |||||
def __iter__(self): | |||||
if self.during_iter:  # the previous iteration never finished; the only option is to force a re-initialization
self.num_consumed_samples = 0 | |||||
self.during_iter = True | |||||
indices = list(range(len(self.dataset))) | |||||
if self.shuffle: | |||||
if self.num_consumed_samples > 0:  # first rebuild the previous ordering, then drop the part already consumed
_batches = [] | |||||
for _i in range(self.old_num_replicas): | |||||
_indices = indices[_i:len(indices):self.old_num_replicas] | |||||
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch) | |||||
_batches.append(__batches) | |||||
batches = list(chain(*[_ for _ in zip(*_batches)])) | |||||
indices = list(chain(*batches)) | |||||
indices = indices[self.num_consumed_samples:] | |||||
# take this rank's share
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch) | |||||
batches = list(map(list, batches)) | |||||
else: | |||||
indices = indices[self.num_consumed_samples:] | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
_num_batches = len(indices) // self.batch_size | |||||
if _num_batches == 0: | |||||
batches = [indices] | |||||
else: | |||||
batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches))) | |||||
if len(indices) % self.batch_size != 0:
batches.append(indices[_num_batches * self.batch_size:])
need_pad_num = (len(self.dataset) - self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num != 0 and need_pad_num <= self.rank:
if len(batches) > 0:
if len(batches[-1]) < self.batch_size:
batches[-1].append(batches[-1][0])  # this keeps the length of the batch intact.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num != 0 and need_pad_num > self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1]) == 0:
batches.pop(-1)
assert sum(map(len, batches)) == self.num_left_samples | |||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||||
batches = batches[:-1] | |||||
for batch in batches: | |||||
self.num_consumed_samples += self.num_replicas * len(batch) | |||||
yield list(map(int, batch)) | |||||
self.during_iter = False | |||||
self.num_consumed_samples = 0 | |||||
self.old_batch_size = self.batch_size | |||||
self.old_num_replicas = self.num_replicas | |||||
if self.epoch < 0:  # guard against the user never calling set_epoch(), which would otherwise make every epoch identical
self.epoch -= 1 | |||||
def batchify(self, indices, batch_size, seed): | |||||
""" | |||||
Split indices into batches.
:param indices: List[int]
:param batch_size: int | |||||
:param seed: int | |||||
:return: List[List[int]] | |||||
""" | |||||
# shuffle deterministically per (seed, epoch), then slice into consecutive batches
rng = np.random.default_rng(abs(seed)) | |||||
rng.shuffle(indices) | |||||
num_samples = 0 | |||||
batches = [] | |||||
while num_samples < len(indices):
batches.append(indices[num_samples:num_samples + batch_size])
num_samples += batch_size
return batches | |||||
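For instance, 10 indices with batch_size=4 yield batch sizes 4/4/2; the slicing above is equivalent to::

import numpy as np

indices = list(range(10))
rng = np.random.default_rng(0)
rng.shuffle(indices)
batches = [indices[i:i + 4] for i in range(0, len(indices), 4)]
print([len(b) for b in batches])  # [4, 4, 2]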
def set_epoch(self, epoch): | |||||
self.epoch = epoch | |||||
@property | |||||
def batch_idx_in_epoch(self): | |||||
if self.drop_last: | |||||
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size | |||||
else: | |||||
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ | |||||
(self.num_left_samples + self.batch_size - 1) // self.batch_size | |||||
@property | |||||
def total_size(self): | |||||
""" | |||||
The total number of indices this sampler will eventually emit, including those for other ranks. Because of
replicas and padding, this can be equal to, greater than, or smaller than len(dataset).
:return: | |||||
""" | |||||
return self.num_consumed_samples + self.num_replicas*self.num_left_samples | |||||
@property | |||||
def num_left_samples(self): | |||||
""" | |||||
Number of samples left in the current iteration, counted for the current rank.
:return: | |||||
""" | |||||
num_consumed_samples = self.num_consumed_samples | |||||
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \ | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | |||||
def __len__(self) -> int:
""" | |||||
Number of batches this sampler will still produce.
:return: | |||||
""" | |||||
num_sampler_per_rank = self.total_size // self.num_replicas
num_batches = num_sampler_per_rank // self.batch_size if self.drop_last else \
(num_sampler_per_rank + self.batch_size - 1) // self.batch_size
return num_batches | |||||
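A worked example of this bookkeeping, before anything is consumed: len(dataset)=10, num_replicas=3, pad=True, batch_size=4::

# num_left_samples = ceil((10 - 0) / 3) = 4   (per rank, with padding)
# total_size       = 0 + 3 * 4       = 12     (> len(dataset) because of padding)
# len(sampler)     = ceil(4 / 4)     = 1      (batches still to come on this rank)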
def state_dict(self) -> Dict: | |||||
if self.old_batch_size != self.batch_size: | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||||
" consumed. ") | |||||
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, | |||||
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | |||||
'num_replicas': self.num_replicas} | |||||
return states | |||||
def load_state_dict(self, states: Dict): | |||||
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ | |||||
"during an unfinished iteration." | |||||
assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in the checkpoint is {states['sampler_type']}, " \
f"so we cannot use {self.__class__.__name__} to load it."
length = states['length'] | |||||
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ | |||||
"and current dataset." | |||||
self.seed = states['seed'] | |||||
self.epoch = states['epoch'] | |||||
self.num_consumed_samples = states['num_consumed_samples'] | |||||
if self.num_consumed_samples >= length:  # if saving happened right at the last sample, simply reset to 0
self.num_consumed_samples = 0 | |||||
if self.shuffle != states['shuffle']: | |||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||||
f"we use shuffle={states['shuffle']}") | |||||
self.shuffle = states["shuffle"] | |||||
self.old_batch_size = states['batch_size'] | |||||
self.old_num_replicas = states['num_replicas'] | |||||
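A minimal round-trip of this state protocol on a single card (runnable sketch)::

sampler = RandomBatchSampler(dataset=list(range(10)), batch_size=4, shuffle=True, seed=1)
states = sampler.state_dict()  # capture before iterating
restored = RandomBatchSampler(dataset=list(range(10)), batch_size=4, shuffle=True)
restored.load_state_dict(states)  # seed/epoch/progress now come from `states`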
class BucketedBatchSampler(ReproducibleBatchSampler): | class BucketedBatchSampler(ReproducibleBatchSampler): | ||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
@@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet | |||||
class ReproducibleSampler: | class ReproducibleSampler: | ||||
""" | """ | ||||
A reproducible Sampler object.
注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler | ||||
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 | ||||
@@ -54,13 +56,12 @@ class RandomSampler(ReproducibleSampler): | |||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
:param dataset: 实现了 __len__ 方法的数据容器 | :param dataset: 实现了 __len__ 方法的数据容器 | ||||
:param shuffle: 是否在每次 iterate 的时候打乱顺序。 | :param shuffle: 是否在每次 iterate 的时候打乱顺序。 | ||||
:param seed: 随机数种子。 | :param seed: 随机数种子。 | ||||
:param kwargs: 用户不需要使用,fastNLP 内部使用 | :param kwargs: 用户不需要使用,fastNLP 内部使用 | ||||
""" | """ | ||||
super(RandomSampler, self).__init__() | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.shuffle = shuffle | self.shuffle = shuffle | ||||
self.seed = seed | self.seed = seed | ||||
@@ -21,7 +21,6 @@ __all__ = [ | |||||
'nullcontext', | 'nullcontext', | ||||
'pretty_table_printer', | 'pretty_table_printer', | ||||
'Option', | 'Option', | ||||
'indice_collate_wrapper', | |||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'rank_zero_rm', | 'rank_zero_rm', | ||||
@@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||||
from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | ||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
from ..dataloaders.utils import indice_collate_wrapper | |||||
@@ -1,5 +1,5 @@ | |||||
import functools | import functools | ||||
class DummyClass: | class DummyClass: | ||||
def __call__(self, *args, **kwargs): | |||||
return | |||||
def __init__(self, *args, **kwargs): | |||||
pass |
@@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]): | |||||
else: | else: | ||||
return data.cuda(get_paddle_device_id(device)) | return data.cuda(get_paddle_device_id(device)) | ||||
def get_paddle_gpu_str(device: Union[str, int]): | def get_paddle_gpu_str(device: Union[str, int]): | ||||
""" | """ | ||||
获得 `gpu:x` 类型的设备名 | 获得 `gpu:x` 类型的设备名 | ||||
@@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]): | |||||
return device.replace("cuda", "gpu") | return device.replace("cuda", "gpu") | ||||
return f"gpu:{device}" | return f"gpu:{device}" | ||||
def get_paddle_device_id(device: Union[str, int]): | def get_paddle_device_id(device: Union[str, int]): | ||||
""" | """ | ||||
获得 gpu 的设备id | 获得 gpu 的设备id | ||||
@@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | |||||
return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | ||||
def is_in_paddle_dist(): | def is_in_paddle_dist(): | ||||
""" | """ | ||||
判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 | 判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 | ||||
""" | """ | ||||
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | ||||
def is_in_fnlp_paddle_dist(): | def is_in_fnlp_paddle_dist(): | ||||
""" | """ | ||||
判断是否处于 FastNLP 拉起的分布式进程中 | 判断是否处于 FastNLP 拉起的分布式进程中 | ||||
""" | """ | ||||
return FASTNLP_DISTRIBUTED_CHECK in os.environ | return FASTNLP_DISTRIBUTED_CHECK in os.environ | ||||
def is_in_paddle_launch_dist(): | def is_in_paddle_launch_dist(): | ||||
""" | """ | ||||
判断是否处于 launch 启动的分布式进程中 | 判断是否处于 launch 启动的分布式进程中 | ||||
@@ -6,7 +6,7 @@ import warnings | |||||
from dataclasses import is_dataclass | from dataclasses import is_dataclass | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from collections import defaultdict, OrderedDict | from collections import defaultdict, OrderedDict | ||||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional | |||||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence | |||||
from typing import Tuple, Optional | from typing import Tuple, Optional | ||||
from time import sleep | from time import sleep | ||||
@@ -35,7 +35,6 @@ __all__ = [ | |||||
'nullcontext', | 'nullcontext', | ||||
'pretty_table_printer', | 'pretty_table_printer', | ||||
'Option', | 'Option', | ||||
'indice_collate_wrapper', | |||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'rank_zero_rm', | 'rank_zero_rm', | ||||
@@ -513,24 +512,6 @@ class Option(dict): | |||||
self.update(state) | self.update(state) | ||||
def indice_collate_wrapper(func): | |||||
""" | |||||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||||
:param func: 需要修饰的函数 | |||||
:return: | |||||
""" | |||||
def wrapper(tuple_data): | |||||
indice, ins_list = [], [] | |||||
for idx, ins in tuple_data: | |||||
indice.append(idx) | |||||
ins_list.append(ins) | |||||
return indice, func(ins_list) | |||||
return wrapper | |||||
_emitted_deprecation_warnings = set() | _emitted_deprecation_warnings = set() | ||||
@@ -332,13 +332,44 @@ class DataBundle: | |||||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | show_progress_bar=show_progress_bar, progress_desc=progress_desc) | ||||
return res | return res | ||||
def set_pad_val(self, *field_names, val=0) -> None: | |||||
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | |||||
""" | |||||
If you need special handling for the content of a particular field, use this function.
:param field_name: name of the field to adjust. If the Dataset's __getitem__ returns a dict, use the
corresponding key directly; for a nested dict, use a tuple of keys, e.g. ('a', 'b') for {'a': {'b': 1}}.
If __getitem__ returns a Sequence, '_0', '_1', ... refer to the 0th, 1st, ... element. An error is raised
if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: default pad value for this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so a field that is not paddable does not need to be set to None
explicitly. Has no effect when backend is None.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing list, numpy.ndarray,
torch.Tensor, paddle.Tensor or jittor.Var output respectively. Has no effect when pad_val is None.
:param pad_fn: a custom pad function for this field; when given, pad_val, dtype and backend are ignored.
The Collator automatically un-batches the data and groups each field into its own batch; pad_fn receives
the batch form of the field, and its output is used directly as the result.
:return: self | |||||
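Ex::
    # a minimal usage sketch; the field names below are hypothetical
    data_bundle.set_pad('words', pad_val=0, backend='torch')
    # with a custom pad_fn, pad_val / dtype / backend are ignored
    data_bundle.set_pad('raw_words', pad_fn=lambda field_batch: list(field_batch))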
""" | |||||
for _, ds in self.iter_datasets(): | for _, ds in self.iter_datasets(): | ||||
ds.set_pad_val(*field_names, val=val) | |||||
ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend, | |||||
pad_fn=pad_fn) | |||||
return self | |||||
def set_input(self, *field_names) -> None: | |||||
def set_ignore(self, *field_names) -> "DataBundle": | |||||
""" | |||||
If some content should not be output, configure it here; the configured fields will be ignored in the batch output.
Ex:: | |||||
collator.set_ignore('field1', 'field2') | |||||
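    # nested-dict keys and Sequence positions work too; names here are hypothetical
    collator.set_ignore(('nested_dict', 'a'), '_0')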
:param field_names: the names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the
corresponding field's key can be used directly; for nested dicts, a tuple can be used, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' refer to the 0th or 1st element.
:return: self | |||||
""" | |||||
for _, ds in self.iter_datasets(): | for _, ds in self.iter_datasets(): | ||||
ds.set_input(*field_names) | |||||
ds.collator.set_ignore(*field_names) | |||||
return self | |||||
def __repr__(self) -> str: | def __repr__(self) -> str: | ||||
_str = '' | _str = '' | ||||
@@ -0,0 +1,208 @@ | |||||
import pytest | |||||
from functools import reduce | |||||
from fastNLP.core.callbacks.callback_event import Event, Filter | |||||
class TestFilter: | |||||
def test_every_filter(self): | |||||
# every = 10 | |||||
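# (the wrapped fn should execute only on every 10th call, i.e. for inputs 9, 19, ..., 99)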
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w-1 for w in range(10, 101, 10)] | |||||
# every = 1 | |||||
@Filter(every=1) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == list(range(100)) | |||||
def test_once_filter(self): | |||||
# once = 10 | |||||
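# (the wrapped fn should execute exactly once, on the 10th call, i.e. input 9)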
@Filter(once=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [9] | |||||
def test_extract_filter_from_fn(self): | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_filter_num_called = [] | |||||
_filter_num_executed = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
_filter = _fn.__fastNLP_filter__ | |||||
_filter_num_called.append(_filter.num_called) | |||||
_filter_num_executed.append(_filter.num_executed) | |||||
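# num_called grows by one per call; num_executed stays at 0 for the first 9 calls
# and then increases by one on every 10th call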
assert _filter_num_called == list(range(1, 101)) | |||||
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10] | |||||
def _fn(data): | |||||
return data | |||||
assert not hasattr(_fn, "__fastNLP_filter__") | |||||
def test_filter_state_dict(self): | |||||
# every = 10 | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(50): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(10, 51, 10)] | |||||
# save the filter state
state = _fn.__fastNLP_filter__.state_dict() | |||||
# load the state back
_fn.__fastNLP_filter__.load_state_dict(state) | |||||
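# after restoring the state, counting continues from call 50 as if uninterrupted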
_res = [] | |||||
for i in range(50, 100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(60, 101, 10)] | |||||
@pytest.mark.torch | |||||
def test_filter_fn_torch(): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
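# filter_fn receives the Filter instance and the trainer; the wrapped fn only
# executes when the ad-hoc attribute __heihei_test__ equals 10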
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
class TestCallbackEvents: | |||||
def test_every(self): | |||||
# which event is used does not matter here, since the Filter is tested separately from the Trainer;
event_state = Event.on_train_begin()  # with no arguments, the default should be every=1;
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == list(range(100)) | |||||
event_state = Event.on_train_begin(every=10) | |||||
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(10, 101, 10)] | |||||
def test_once(self): | |||||
event_state = Event.on_train_begin(once=10) | |||||
@Filter(once=event_state.once) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [9] | |||||
@pytest.mark.torch | |||||
def test_callback_events_torch(): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
event_state = Event.on_train_begin(filter_fn=filter_fn) | |||||
@Filter(filter_fn=event_state.filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
@@ -1,157 +0,0 @@ | |||||
import pytest | |||||
from functools import reduce | |||||
from fastNLP.core.callbacks.callback_events import Events, Filter | |||||
class TestFilter: | |||||
def test_params_check(self): | |||||
# these pass without error
_filter1 = Filter(every=10) | |||||
_filter2 = Filter(once=10) | |||||
_filter3 = Filter(filter_fn=lambda: None) | |||||
# triggers ValueError
with pytest.raises(ValueError) as e: | |||||
_filter4 = Filter() | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter." | |||||
# triggers ValueError
with pytest.raises(ValueError) as e: | |||||
_filter5 = Filter(every=10, once=10) | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "These three values should be only set one." | |||||
# triggers ValueError
with pytest.raises(ValueError) as e: | |||||
_filter6 = Filter(every="heihei") | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "Argument every should be integer and greater than zero" | |||||
# triggers ValueError
with pytest.raises(ValueError) as e: | |||||
_filter7 = Filter(once="heihei") | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "Argument once should be integer and positive" | |||||
# triggers TypeError
with pytest.raises(TypeError) as e: | |||||
_filter7 = Filter(filter_fn="heihei") | |||||
exec_msg = e.value.args[0] | |||||
assert exec_msg == "Argument event_filter should be a callable" | |||||
def test_every_filter(self): | |||||
# every = 10 | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w-1 for w in range(10, 101, 10)] | |||||
# every = 1 | |||||
@Filter(every=1) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == list(range(100)) | |||||
def test_once_filter(self): | |||||
# once = 10 | |||||
@Filter(once=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [9] | |||||
def test_filter_fn(self): | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) | |||||
optimizer = SGD(model.parameters(), lr=0.0001) | |||||
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) | |||||
dataloader = DataLoader(dataset=dataset, batch_size=4) | |||||
trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.__heihei_test__ == 10: | |||||
return True | |||||
return False | |||||
@Filter(filter_fn=filter_fn) | |||||
def _fn(trainer, data): | |||||
return data | |||||
_res = [] | |||||
for i in range(100): | |||||
trainer.__heihei_test__ = i | |||||
cu_res = _fn(trainer, i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [10] | |||||
def test_extract_filter_from_fn(self): | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_filter_num_called = [] | |||||
_filter_num_executed = [] | |||||
for i in range(100): | |||||
cu_res = _fn(i) | |||||
_filter = _fn.__fastNLP_filter__ | |||||
_filter_num_called.append(_filter.num_called) | |||||
_filter_num_executed.append(_filter.num_executed) | |||||
assert _filter_num_called == list(range(1, 101)) | |||||
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10] | |||||
def _fn(data): | |||||
return data | |||||
assert not hasattr(_fn, "__fastNLP_filter__") | |||||
def test_filter_state_dict(self): | |||||
# every = 10 | |||||
@Filter(every=10) | |||||
def _fn(data): | |||||
return data | |||||
_res = [] | |||||
for i in range(50): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(10, 51, 10)] | |||||
# save the filter state
state = _fn.__fastNLP_filter__.state_dict() | |||||
# load the state back
_fn.__fastNLP_filter__.load_state_dict(state) | |||||
_res = [] | |||||
for i in range(50, 100): | |||||
cu_res = _fn(i) | |||||
if cu_res is not None: | |||||
_res.append(cu_res) | |||||
assert _res == [w - 1 for w in range(60, 101, 10)] | |||||
@@ -2,9 +2,6 @@ import os | |||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from pathlib import Path | from pathlib import Path | ||||
import re | import re | ||||
import time | import time | ||||
@@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
@@ -216,9 +218,9 @@ def test_model_checkpoint_callback_2( | |||||
path = Path.cwd().joinpath("test_model_checkpoint") | path = Path.cwd().joinpath("test_model_checkpoint") | ||||
path.mkdir(exist_ok=True, parents=True) | path.mkdir(exist_ok=True, parents=True) | ||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
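# note: Event.on_train_epoch_end is a factory method and must be called, unlike the old Events enum attribute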
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2( | |||||
if version == 0: | if version == 0: | ||||
callbacks = [ | callbacks = [ | ||||
TrainerCheckpointCallback( | |||||
CheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
folder=path, | folder=path, | ||||
every_n_epochs=None, | every_n_epochs=None, | ||||
@@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2( | |||||
topk=None, | topk=None, | ||||
last=False, | last=False, | ||||
on_exception=None, | on_exception=None, | ||||
model_save_fn=model_save_fn | |||||
model_save_fn=model_save_fn, | |||||
save_object="trainer" | |||||
) | ) | ||||
] | ] | ||||
elif version == 1: | elif version == 1: | ||||
callbacks = [ | callbacks = [ | ||||
TrainerCheckpointCallback( | |||||
CheckpointCallback( | |||||
monitor="acc", | monitor="acc", | ||||
folder=path, | folder=path, | ||||
every_n_epochs=None, | every_n_epochs=None, | ||||
@@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2( | |||||
topk=1, | topk=1, | ||||
last=True, | last=True, | ||||
on_exception=None, | on_exception=None, | ||||
model_save_fn=model_save_fn | |||||
model_save_fn=model_save_fn, | |||||
save_object="trainer" | |||||
) | ) | ||||
] | ] | ||||
@@ -12,9 +12,7 @@ import os | |||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
from pathlib import Path | from pathlib import Path | ||||
import re | import re | ||||
@@ -29,7 +27,11 @@ from torchmetrics import Accuracy | |||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.callbacks import MoreEvaluateCallback | from fastNLP.core.callbacks import MoreEvaluateCallback | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
class ArgMaxDatasetConfig: | class ArgMaxDatasetConfig: | ||||
@@ -17,12 +17,13 @@ def test_get_element_shape_dtype(): | |||||
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | ||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
@pytest.mark.jittor | |||||
def test_get_padder_run(backend): | def test_get_padder_run(backend): | ||||
if not _NEED_IMPORT_TORCH and backend == 'torch': | if not _NEED_IMPORT_TORCH and backend == 'torch': | ||||
pytest.skip("No torch") | pytest.skip("No torch") | ||||
if not _NEED_IMPORT_PADDLE and backend == 'paddle': | if not _NEED_IMPORT_PADDLE and backend == 'paddle': | ||||
pytest.skip("No paddle") | pytest.skip("No paddle") | ||||
if not _NEED_IMPORT_PADDLE and backend == 'jittor': | |||||
if not _NEED_IMPORT_JITTOR and backend == 'jittor': | |||||
pytest.skip("No jittor") | pytest.skip("No jittor") | ||||
batch_field = [1, 2, 3] | batch_field = [1, 2, 3] | ||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | ||||
@@ -66,6 +67,13 @@ def test_raw_padder(): | |||||
pad_batch = padder(batch_field) | pad_batch = padder(batch_field) | ||||
assert np.shape(pad_batch) == (3, 3, 2) | assert np.shape(pad_batch) == (3, 3, 2) | ||||
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, list) | |||||
assert np.shape(pad_batch) == (3, 3, 3) | |||||
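# 12 pad zeros expected: 0 from the (3,3) item, 3 from the (2,3) item, 9 from the (1,0) item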
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | |||||
def test_numpy_padder(): | def test_numpy_padder(): | ||||
backend = 'numpy' | backend = 'numpy' | ||||
@@ -140,3 +148,18 @@ def test_torch_padder(): | |||||
with pytest.raises(InconsistencyError): | with pytest.raises(InconsistencyError): | ||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | ||||
# numpy.ndarray input is also accepted
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, target_type) | |||||
assert pad_batch.shape == (3, 3, 3) | |||||
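# 12 pad zeros expected: 0 from the (3,3) item, 3 from the (2,3) item, 9 from the (1,0) item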
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 | |||||
# test converting torch tensors to numpy output
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] | |||||
padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test') | |||||
pad_batch = padder(batch_field) | |||||
assert isinstance(pad_batch, np.ndarray) | |||||
assert np.shape(pad_batch) == (3, 3, 3) | |||||
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 |
@@ -1,7 +1,7 @@ | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder | |||||
from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | |||||
from fastNLP.core.collators.padders.exceptions import DtypeError | from fastNLP.core.collators.padders.exceptions import DtypeError | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
@@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE: | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class TestpaddleNumberPadder: | |||||
class TestPaddleNumberPadder: | |||||
def test_run(self): | def test_run(self): | ||||
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
t_a = padder(a) | t_a = padder(a) | ||||
assert isinstance(t_a, paddle.Tensor) | assert isinstance(t_a, paddle.Tensor) | ||||
@@ -20,9 +20,9 @@ class TestpaddleNumberPadder: | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class TestpaddleSequencePadder: | |||||
class TestPaddleSequencePadder: | |||||
def test_run(self): | def test_run(self): | ||||
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -32,20 +32,20 @@ class TestpaddleSequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) | |||||
a = padder([[1], [2, 322]]) | a = padder([[1], [2, 322]]) | ||||
# assert (a>67).sum()==0 # int8 ranges from -128 to 127 | # assert (a>67).sum()==0 # int8 ranges from -128 to 127
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class TestpaddleTensorPadder: | |||||
class TestPaddleTensorPadder: | |||||
def test_run(self): | def test_run(self): | ||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) | |||||
a = [paddle.zeros((3,)), paddle.zeros((2,))] | a = [paddle.zeros((3,)), paddle.zeros((2,))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -74,7 +74,7 @@ class TestpaddleTensorPadder: | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | [[0, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) | |||||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] | a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -85,7 +85,7 @@ class TestpaddleTensorPadder: | |||||
]) | ]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) | |||||
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] | a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -96,11 +96,11 @@ class TestpaddleTensorPadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) | |||||
def test_v1(self): | def test_v1(self): | ||||
print(paddle.zeros((3, )).dtype) | print(paddle.zeros((3, )).dtype) |
@@ -23,7 +23,6 @@ class TestRawSequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
with pytest.raises(DtypeError): | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) |
@@ -1,81 +1,293 @@ | |||||
import numpy as np | |||||
import pytest | import pytest | ||||
from fastNLP.core.collators import AutoCollator | |||||
from fastNLP.core.collators.collator import _MultiCollator | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.collators.collator import Collator | |||||
def _assert_equal(d1, d2): | |||||
try: | |||||
if 'torch' in str(type(d1)): | |||||
if 'float64' in str(d2.dtype): | |||||
print(d2.dtype) | |||||
assert (d1 == d2).all().item() | |||||
else: | |||||
assert all(d1 == d2) | |||||
except TypeError: | |||||
assert d1 == d2 | |||||
except ValueError: | |||||
assert (d1 == d2).all() | |||||
def findDictDiff(d1, d2, path=""): | |||||
for k in d1: | |||||
if k in d2: | |||||
if isinstance(d1[k], dict): | |||||
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) | |||||
else: | |||||
_assert_equal(d1[k], d2[k]) | |||||
else: | |||||
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) | |||||
def findListDiff(d1, d2): | |||||
assert len(d1)==len(d2) | |||||
for _d1, _d2 in zip(d1, d2): | |||||
if isinstance(_d1, list): | |||||
findListDiff(_d1, _d2) | |||||
else: | |||||
_assert_equal(_d1, _d2) | |||||
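# the helpers above recursively compare padded batches, falling back to
# elementwise comparison for tensors / ndarrays nested inside dicts and lists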
class TestCollator: | class TestCollator: | ||||
@pytest.mark.parametrize('as_numpy', [True, False]) | |||||
def test_auto_collator(self, as_numpy): | |||||
""" | |||||
Test the auto_pad functionality of AutoCollator.
:param as_numpy: | |||||
:return: | |||||
""" | |||||
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, | |||||
'y': [0, 1, 1, 0] * 100}) | |||||
collator = AutoCollator(as_numpy=as_numpy) | |||||
collator.set_input('x', 'y') | |||||
bucket_data = [] | |||||
data = [] | |||||
for i in range(len(dataset)): | |||||
data.append(dataset[i]) | |||||
if len(data) == 40: | |||||
bucket_data.append(data) | |||||
data = [] | |||||
results = [] | |||||
for bucket in bucket_data: | |||||
res = collator(bucket) | |||||
assert res['x'].shape == (40, 5) | |||||
assert res['y'].shape == (40,) | |||||
results.append(res) | |||||
def test_auto_collator_v1(self): | |||||
""" | |||||
Test AutoCollator's set_pad_val and set_as_numpy functionality.
:return: | |||||
""" | |||||
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, | |||||
'y': [0, 1, 1, 0] * 100}) | |||||
collator = AutoCollator(as_numpy=False) | |||||
collator.set_input('x') | |||||
collator.set_pad_val('x', val=-1) | |||||
collator.set_as_numpy(True) | |||||
bucket_data = [] | |||||
data = [] | |||||
for i in range(len(dataset)): | |||||
data.append(dataset[i]) | |||||
if len(data) == 40: | |||||
bucket_data.append(data) | |||||
data = [] | |||||
for bucket in bucket_data: | |||||
res = collator(bucket) | |||||
print(res) | |||||
def test_multicollator(self): | |||||
""" | |||||
Test the _MultiCollator functionality.
:return: | |||||
""" | |||||
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, | |||||
'y': [0, 1, 1, 0] * 100}) | |||||
collator = AutoCollator(as_numpy=False) | |||||
multi_collator = _MultiCollator(collator) | |||||
multi_collator.set_as_numpy(as_numpy=True) | |||||
multi_collator.set_pad_val('x', val=-1) | |||||
multi_collator.set_input('x') | |||||
bucket_data = [] | |||||
data = [] | |||||
for i in range(len(dataset)): | |||||
data.append(dataset[i]) | |||||
if len(data) == 40: | |||||
bucket_data.append(data) | |||||
data = [] | |||||
for bucket in bucket_data: | |||||
res = multi_collator(bucket) | |||||
print(res) | |||||
@pytest.mark.torch | |||||
def test_run(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'a': 1, 'b':[1, 2]} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'a': 2, 'b': [1, 2]} | |||||
} | |||||
] | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||||
collator = Collator(backend='raw') | |||||
assert raw_pad_batch == collator(dict_batch) | |||||
collator = Collator(backend='raw') | |||||
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
collator = Collator(backend='numpy') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), | |||||
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), | |||||
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), | |||||
'b': np.array([[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='numpy') | |||||
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), | |||||
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), | |||||
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(numpy_pad_lst, collator(list_batch)) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
collator = Collator(backend='torch') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), | |||||
'lst_int': torch.LongTensor([[1, 0], [1, 2]]), | |||||
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
'float': torch.FloatTensor([1.1, 2.1]), | |||||
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), | |||||
'numpy': torch.FloatTensor([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), | |||||
'b': torch.LongTensor( | |||||
[[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='torch') | |||||
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), | |||||
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), | |||||
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(torch_pad_lst, collator(list_batch)) | |||||
def test_pad(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'a': 1, 'b':[1, 2]} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'a': 2, 'b': [1, 2]} | |||||
} | |||||
] | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||||
# test set_ignore
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# test set_pad
collator = Collator(backend='raw') | |||||
collator.set_pad('str', pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
# test setting the pad value
collator = Collator(backend='raw') | |||||
collator.set_pad('nest_lst_int', pad_val=100) | |||||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||||
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
# [{'1'}, {'2'}]] | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('_0', '_3', '_1') | |||||
collator.set_pad('_4', pad_val=None) | |||||
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad('_0', pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('_0', '_3', '_1') | |||||
collator.set_pad('_2', backend='numpy') | |||||
collator.set_pad('_4', backend='numpy', pad_val=100) | |||||
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
# _single | |||||
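# ('_single' treats each whole sample as one field to pad; see the set_pad docs)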
collator = Collator() | |||||
collator.set_pad('_single') | |||||
findListDiff(list_batch, collator(list_batch)) | |||||
def test_nest_ignore(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} | |||||
} | |||||
] | |||||
# test set_ignore
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||||
'c': {'int':[1, 1]}}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad(('nested_dict', 'c'), pad_val=None) | |||||
collator.set_ignore('str', 'int', 'lst_int') | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||||
'c': [{'int':1}, {'int':1}]}} | |||||
pad_batch = collator(dict_batch) | |||||
findDictDiff(raw_pad_batch, pad_batch) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad(('nested_dict', 'c'), pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int') | |||||
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) | |||||
pad_batch = collator(dict_batch) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||||
'c': [1, 1]}} | |||||
findDictDiff(raw_pad_batch, pad_batch) | |||||
@@ -1,293 +0,0 @@ | |||||
import numpy as np | |||||
import pytest | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||||
from fastNLP.core.collators.new_collator import Collator | |||||
def _assert_equal(d1, d2): | |||||
try: | |||||
if 'torch' in str(type(d1)): | |||||
if 'float64' in str(d2.dtype): | |||||
print(d2.dtype) | |||||
assert (d1 == d2).all().item() | |||||
else: | |||||
assert all(d1 == d2) | |||||
except TypeError: | |||||
assert d1 == d2 | |||||
except ValueError: | |||||
assert (d1 == d2).all() | |||||
def findDictDiff(d1, d2, path=""): | |||||
for k in d1: | |||||
if k in d2: | |||||
if isinstance(d1[k], dict): | |||||
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) | |||||
else: | |||||
_assert_equal(d1[k], d2[k]) | |||||
else: | |||||
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) | |||||
def findListDiff(d1, d2): | |||||
assert len(d1)==len(d2) | |||||
for _d1, _d2 in zip(d1, d2): | |||||
if isinstance(_d1, list): | |||||
findListDiff(_d1, _d2) | |||||
else: | |||||
_assert_equal(_d1, _d2) | |||||
class TestCollator: | |||||
@pytest.mark.torch | |||||
def test_run(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'a': 1, 'b':[1, 2]} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'a': 2, 'b': [1, 2]} | |||||
} | |||||
] | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||||
collator = Collator(backend='raw') | |||||
assert raw_pad_batch == collator(dict_batch) | |||||
collator = Collator(backend='raw') | |||||
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
collator = Collator(backend='numpy') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), | |||||
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), | |||||
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), | |||||
'b': np.array([[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='numpy') | |||||
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), | |||||
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), | |||||
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(numpy_pad_lst, collator(list_batch)) | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
collator = Collator(backend='torch') | |||||
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), | |||||
'lst_int': torch.LongTensor([[1, 0], [1, 2]]), | |||||
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
'float': torch.FloatTensor([1.1, 2.1]), | |||||
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), | |||||
'numpy': torch.FloatTensor([[1], [0]]), | |||||
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), | |||||
'b': torch.LongTensor( | |||||
[[1, 2], [1, 2]])}} | |||||
findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='torch') | |||||
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), | |||||
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||||
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), | |||||
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(torch_pad_lst, collator(list_batch)) | |||||
def test_pad(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'a': 1, 'b':[1, 2]} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'a': 2, 'b': [1, 2]} | |||||
} | |||||
] | |||||
raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||||
# test set_ignore
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# test set_pad
collator = Collator(backend='raw') | |||||
collator.set_pad('str', pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
# test setting the pad value
collator = Collator(backend='raw') | |||||
collator.set_pad('nest_lst_int', pad_val=100) | |||||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||||
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
# [{'1'}, {'2'}]] | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('_0', '_3', '_1') | |||||
collator.set_pad('_4', pad_val=None) | |||||
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad('_0', pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||||
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('_0', '_3', '_1') | |||||
collator.set_pad('_2', backend='numpy') | |||||
collator.set_pad('_4', backend='numpy', pad_val=100) | |||||
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | |||||
findListDiff(raw_pad_lst, collator(list_batch)) | |||||
# _single | |||||
collator = Collator() | |||||
collator.set_pad('_single') | |||||
findListDiff(list_batch, collator(list_batch)) | |||||
def test_nest_ignore(self): | |||||
dict_batch = [{ | |||||
'str': '1', | |||||
'lst_str': ['1'], | |||||
'int': 1, | |||||
'lst_int': [1], | |||||
'nest_lst_int': [[1]], | |||||
'float': 1.1, | |||||
'lst_float': [1.1], | |||||
'bool': True, | |||||
'numpy': np.ones(1), | |||||
'dict': {'1': '1'}, | |||||
'set': {'1'}, | |||||
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} | |||||
}, | |||||
{ | |||||
'str': '2', | |||||
'lst_str': ['2', '2'], | |||||
'int': 2, | |||||
'lst_int': [1, 2], | |||||
'nest_lst_int': [[1], [1, 2]], | |||||
'float': 2.1, | |||||
'lst_float': [2.1], | |||||
'bool': False, | |||||
'numpy': np.zeros(1), | |||||
'dict': {'1': '2'}, | |||||
'set': {'2'}, | |||||
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} | |||||
} | |||||
] | |||||
# test set_ignore
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||||
'c': {'int':[1, 1]}}} | |||||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad(('nested_dict', 'c'), pad_val=None) | |||||
collator.set_ignore('str', 'int', 'lst_int') | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||||
'c': [{'int':1}, {'int':1}]}} | |||||
pad_batch = collator(dict_batch) | |||||
findDictDiff(raw_pad_batch, pad_batch) | |||||
collator = Collator(backend='raw') | |||||
collator.set_pad(('nested_dict', 'c'), pad_val=1) | |||||
with pytest.raises(BaseException): | |||||
collator(dict_batch) | |||||
collator = Collator(backend='raw') | |||||
collator.set_ignore('str', 'int', 'lst_int') | |||||
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) | |||||
pad_batch = collator(dict_batch) | |||||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||||
'c': [1, 1]}} | |||||
findDictDiff(raw_pad_batch, pad_batch) | |||||
@@ -1,17 +1,20 @@ | |||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from torchmetrics import Accuracy | |||||
import torch.distributed as dist | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from torchmetrics import Accuracy | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
@@ -62,12 +65,11 @@ def model_and_optimizers(): | |||||
return trainer_params | return trainer_params | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger( | |||||
def test_trainer_event_trigger_1( | |||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
driver, | driver, | ||||
device, | device, | ||||
@@ -97,8 +99,215 @@ def test_trainer_event_trigger( | |||||
if dist.is_initialized(): | if dist.is_initialized(): | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
for name, member in Events.__members__.items(): | |||||
assert member.value in output[0] | |||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
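The rename from the `Events` enum to the `Event` class is the thread running through these hunks: handlers are now registered against factory calls such as `Event.on_train_begin()` rather than enum members, which is why the assertion above enumerates the staticmethods in `Event.__dict__`. A minimal sketch of the new registration pattern, using only imports that appear elsewhere in this diff:

from fastNLP.core.callbacks import Event
from fastNLP.core.controllers.trainer import Trainer

# Each Event.on_* staticmethod returns an event object that Trainer.on accepts;
# the decorated handler must match the signature that event supplies.
@Trainer.on(Event.on_train_begin())
def announce_start(trainer):
    print('on_train_begin fired')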
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@magic_argv_env_context | |||||
def test_trainer_event_trigger_2( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=2, | |||||
): | |||||
@Trainer.on(Event.on_after_trainer_initialized()) | |||||
def on_after_trainer_initialized(trainer, driver): | |||||
print("on_after_trainer_initialized") | |||||
@Trainer.on(Event.on_sanity_check_begin()) | |||||
def on_sanity_check_begin(trainer): | |||||
print("on_sanity_check_begin") | |||||
@Trainer.on(Event.on_sanity_check_end()) | |||||
def on_sanity_check_end(trainer, sanity_check_res): | |||||
print("on_sanity_check_end") | |||||
@Trainer.on(Event.on_train_begin()) | |||||
def on_train_begin(trainer): | |||||
print("on_train_begin") | |||||
@Trainer.on(Event.on_train_end()) | |||||
def on_train_end(trainer): | |||||
print("on_train_end") | |||||
@Trainer.on(Event.on_train_epoch_begin()) | |||||
def on_train_epoch_begin(trainer): | |||||
if trainer.cur_epoch_idx >= 1: | |||||
# trigger on_exception;
raise Exception | |||||
print("on_train_epoch_begin") | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
def on_train_epoch_end(trainer): | |||||
print("on_train_epoch_end") | |||||
@Trainer.on(Event.on_fetch_data_begin()) | |||||
def on_fetch_data_begin(trainer): | |||||
print("on_fetch_data_begin") | |||||
@Trainer.on(Event.on_fetch_data_end()) | |||||
def on_fetch_data_end(trainer): | |||||
print("on_fetch_data_end") | |||||
@Trainer.on(Event.on_train_batch_begin()) | |||||
def on_train_batch_begin(trainer, batch, indices=None): | |||||
print("on_train_batch_begin") | |||||
@Trainer.on(Event.on_train_batch_end()) | |||||
def on_train_batch_end(trainer): | |||||
print("on_train_batch_end") | |||||
@Trainer.on(Event.on_exception()) | |||||
def on_exception(trainer, exception): | |||||
print("on_exception") | |||||
@Trainer.on(Event.on_before_backward()) | |||||
def on_before_backward(trainer, outputs): | |||||
print("on_before_backward") | |||||
@Trainer.on(Event.on_after_backward()) | |||||
def on_after_backward(trainer): | |||||
print("on_after_backward") | |||||
@Trainer.on(Event.on_before_optimizers_step()) | |||||
def on_before_optimizers_step(trainer, optimizers): | |||||
print("on_before_optimizers_step") | |||||
@Trainer.on(Event.on_after_optimizers_step()) | |||||
def on_after_optimizers_step(trainer, optimizers): | |||||
print("on_after_optimizers_step") | |||||
@Trainer.on(Event.on_before_zero_grad()) | |||||
def on_before_zero_grad(trainer, optimizers): | |||||
print("on_before_zero_grad") | |||||
@Trainer.on(Event.on_after_zero_grad()) | |||||
def on_after_zero_grad(trainer, optimizers): | |||||
print("on_after_zero_grad") | |||||
@Trainer.on(Event.on_evaluate_begin()) | |||||
def on_evaluate_begin(trainer): | |||||
print("on_evaluate_begin") | |||||
@Trainer.on(Event.on_evaluate_end()) | |||||
def on_evaluate_end(trainer, results): | |||||
print("on_evaluate_end") | |||||
with pytest.raises(Exception): | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
Event_attrs = Event.__dict__ | |||||
for k, v in Event_attrs.items(): | |||||
if isinstance(v, staticmethod): | |||||
assert k in output[0] | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) | |||||
@pytest.mark.torch | |||||
@magic_argv_env_context | |||||
def test_trainer_event_trigger_3( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=2, | |||||
): | |||||
import re | |||||
once_message_1 = "This message should be typed 1 times." | |||||
once_message_2 = "test_filter_fn" | |||||
once_message_3 = "once message 3" | |||||
twice_message = "twice message hei hei" | |||||
@Trainer.on(Event.on_train_epoch_begin(every=2)) | |||||
def train_epoch_begin_1(trainer): | |||||
print(once_message_1) | |||||
@Trainer.on(Event.on_train_epoch_begin()) | |||||
def train_epoch_begin_2(trainer): | |||||
print(twice_message) | |||||
@Trainer.on(Event.on_train_epoch_begin(once=2)) | |||||
def train_epoch_begin_3(trainer): | |||||
print(once_message_3) | |||||
def filter_fn(filter, trainer): | |||||
if trainer.cur_epoch_idx == 1: | |||||
return True | |||||
else: | |||||
return False | |||||
@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn)) | |||||
def test_filter_fn(trainer): | |||||
print(once_message_2) | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
once_pattern_1 = re.compile(once_message_1) | |||||
once_pattern_2 = re.compile(once_message_2) | |||||
once_pattern_3 = re.compile(once_message_3) | |||||
twice_pattern = re.compile(twice_message) | |||||
once_res_1 = once_pattern_1.findall(output[0]) | |||||
assert len(once_res_1) == 1 | |||||
once_res_2 = once_pattern_2.findall(output[0]) | |||||
assert len(once_res_2) == 1 | |||||
once_res_3 = once_pattern_3.findall(output[0]) | |||||
assert len(once_res_3) == 1 | |||||
twice_res = twice_pattern.findall(output[0]) | |||||
assert len(twice_res) == 2 | |||||
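For reference, the three filters exercised by this test compose as follows over its two-epoch run (a sketch of the observed behaviour, not a full description of the filter API):

from fastNLP.core.callbacks import Event

Event.on_train_epoch_begin(every=2)  # fires on every 2nd trigger: once across 2 epochs
Event.on_train_epoch_begin(once=2)   # fires only on the 2nd trigger, then never again
Event.on_train_epoch_end(            # custom predicate: fires whenever it returns True
    filter_fn=lambda filter, trainer: trainer.cur_epoch_idx == 1)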
@@ -1,22 +1,22 @@ | |||||
import pytest | import pytest | ||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.core.callbacks import Events | |||||
from fastNLP.core.callbacks import Event | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator(): | def test_trainer_torch_without_evaluator(): | ||||
@Trainer.on(Events.on_train_epoch_begin(every=10)) | |||||
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things") | |||||
def fn1(trainer): | def fn1(trainer): | ||||
pass | pass | ||||
@Trainer.on(Events.on_train_batch_begin(every=10)) | |||||
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") | |||||
def fn2(trainer, batch, indices): | def fn2(trainer, batch, indices): | ||||
pass | pass | ||||
with pytest.raises(AssertionError): | |||||
@Trainer.on(Events.on_train_batch_begin(every=10)) | |||||
with pytest.raises(BaseException): | |||||
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") | |||||
def fn3(trainer, batch): | def fn3(trainer, batch): | ||||
pass | pass | ||||
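The `marker=...` argument added in this hunk scopes module-level registrations: a handler registered with a marker is only attached to a trainer carrying the same marker, which keeps callbacks from one test from leaking into another. A sketch, assuming the Trainer constructor accepts a matching `marker` keyword:

from fastNLP.core.callbacks import Event
from fastNLP.core.controllers.trainer import Trainer

@Trainer.on(Event.on_train_batch_begin(every=10), marker='my_run')
def log_batch(trainer, batch, indices):
    pass

# Only a Trainer created with marker='my_run' would pick up log_batch;
# trainers created without it (or with another marker) ignore the handler.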
@@ -2,9 +2,7 @@ | |||||
Note that the test functions in this file are expected to be adapted from functions already covered in `test_trainer_w_evaluator_torch.py`, with metrics and an evaluator added; | Note that the test functions in this file are expected to be adapted from functions already covered in `test_trainer_w_evaluator_torch.py`, with metrics and an evaluator added; | ||||
""" | """ | ||||
import pytest | import pytest | ||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
import torch.distributed as dist | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Any | from typing import Any | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
@@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset | ||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
import torch.distributed as dist | |||||
@dataclass | @dataclass | ||||
class NormalClassificationTrainTorchConfig: | class NormalClassificationTrainTorchConfig: | ||||
@@ -2,9 +2,7 @@ import os.path | |||||
import subprocess | import subprocess | ||||
import sys | import sys | ||||
import pytest | import pytest | ||||
import torch.distributed as dist | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from typing import Any | from typing import Any | ||||
from pathlib import Path | from pathlib import Path | ||||
@@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch.distributed as dist | |||||
from torch.optim import SGD | |||||
from torch.utils.data import DataLoader | |||||
@dataclass | @dataclass | ||||
@@ -257,9 +260,9 @@ def test_trainer_on_exception( | |||||
cur_rank, | cur_rank, | ||||
n_epochs=2, | n_epochs=2, | ||||
): | ): | ||||
from fastNLP.core.callbacks.callback_events import Events | |||||
from fastNLP.core.callbacks.callback_event import Event | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
@Trainer.on(Event.on_train_epoch_end()) | |||||
def raise_exception(trainer): | def raise_exception(trainer): | ||||
if trainer.driver.get_local_rank() == cur_rank: | if trainer.driver.get_local_rank() == cur_rank: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -286,6 +289,7 @@ def test_trainer_on_exception( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | @pytest.mark.parametrize("version", [0, 1, 2, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_distributed_launch_1(version): | def test_torch_distributed_launch_1(version): | ||||
@@ -1,7 +1,7 @@ | |||||
from functools import reduce | from functools import reduce | ||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: this class has changed; remember to update the tests too; | from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: this class has changed; remember to update the tests too; | ||||
from tests.helpers.datasets.normal_data import NormalIterator | |||||
from tests.helpers.datasets.normal_data import NormalSampler | |||||
class Test_WrapDataLoader: | class Test_WrapDataLoader: | ||||
@@ -9,9 +9,9 @@ class Test_WrapDataLoader: | |||||
def test_normal_generator(self): | def test_normal_generator(self): | ||||
all_sanity_batches = [4, 20, 100] | all_sanity_batches = [4, 20, 100] | ||||
for sanity_batches in all_sanity_batches: | for sanity_batches in all_sanity_batches: | ||||
data = NormalIterator(num_of_data=1000) | |||||
data = NormalSampler(num_of_data=1000) | |||||
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) | wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) | ||||
dataloader = iter(wrapper(dataloader=data)) | |||||
dataloader = iter(wrapper) | |||||
mark = 0 | mark = 0 | ||||
while True: | while True: | ||||
try: | try: | ||||
@@ -32,8 +32,7 @@ class Test_WrapDataLoader: | |||||
dataset = TorchNormalDataset(num_of_data=1000) | dataset = TorchNormalDataset(num_of_data=1000) | ||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | ||||
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | ||||
dataloader = wrapper(dataloader) | |||||
dataloader = iter(dataloader) | |||||
dataloader = iter(wrapper) | |||||
all_supposed_running_data_num = 0 | all_supposed_running_data_num = 0 | ||||
while True: | while True: | ||||
try: | try: | ||||
@@ -55,6 +54,5 @@ class Test_WrapDataLoader: | |||||
dataset = TorchNormalDataset(num_of_data=1000) | dataset = TorchNormalDataset(num_of_data=1000) | ||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | ||||
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | ||||
dataloader = wrapper(dataloader) | |||||
length.append(len(dataloader)) | |||||
length.append(len(wrapper)) | |||||
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) | assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) |
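These hunks track an interface change in `_TruncatedDataLoader`: the wrapper used to be re-invoked with the dataloader (`wrapper(dataloader)`) before use, whereas it is now iterated and measured directly. A sketch of the new usage, assuming (as the first test does with NormalSampler) that any sized iterable of batches is accepted:

from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader

batches = list(range(100))  # stand-in for a dataloader yielding 100 batches
wrapper = _TruncatedDataLoader(dataloader=batches, num_batches=10)
assert len(wrapper) == 10   # reports the truncated length directly
for batch in wrapper:       # and is iterated directly, with no second call
    pass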
@@ -15,7 +15,7 @@ else: | |||||
class Model (Module): | |||||
class Model(Module): | |||||
def __init__ (self): | def __init__ (self): | ||||
super (Model, self).__init__() | super (Model, self).__init__() | ||||
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | self.conv1 = nn.Conv (3, 32, 3, 1) # no padding | ||||
@@ -45,6 +45,7 @@ class Model (Module): | |||||
return x | return x | ||||
@pytest.mark.jittor | @pytest.mark.jittor | ||||
@pytest.mark.skip("Skip jittor tests now.") | |||||
class TestSingleDevice: | class TestSingleDevice: | ||||
def test_on_gpu_without_fp16(self): | def test_on_gpu_without_fp16(self): | ||||
@@ -2,7 +2,7 @@ import pytest | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -278,7 +278,7 @@ class TestPaddleDriverFunctions: | |||||
dataset = PaddleNormalDataset() | dataset = PaddleNormalDataset() | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset, | dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), | BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), | ||||
batch_size, | batch_size, | ||||
drop_last, | drop_last, | ||||
@@ -287,7 +287,7 @@ class TestPaddleDriverFunctions: | |||||
res = PaddleSingleDriver.get_dataloader_args(dataloader) | res = PaddleSingleDriver.get_dataloader_args(dataloader) | ||||
assert isinstance(res.dataset, PaddleNormalDataset) | assert isinstance(res.dataset, PaddleNormalDataset) | ||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler) | |||||
if shuffle: | if shuffle: | ||||
assert isinstance(res.sampler, paddle.io.RandomSampler) | assert isinstance(res.sampler, paddle.io.RandomSampler) | ||||
else: | else: | ||||
@@ -387,7 +387,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
Tests the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True | Tests the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True | ||||
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True), | When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True), | ||||
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler | |||||
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
@@ -400,7 +400,7 @@ class TestSetDistReproDataloader: | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
else: | else: | ||||
# 此时会替换 batch_sampler | # 此时会替换 batch_sampler | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -414,11 +414,11 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned, with the batch_sampler replaced by the Sampler passed as dist | A new dataloader should be returned, with the batch_sampler replaced by the Sampler passed as dist | ||||
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) | ||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) | |||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | assert replaced_loader.batch_sampler is dist | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | ||||
@@ -450,7 +450,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=self.dataset, | dataset=self.dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), | BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), | ||||
batch_size=4, | batch_size=4, | ||||
drop_last=False, | drop_last=False, | ||||
@@ -459,7 +459,7 @@ class TestSetDistReproDataloader: | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -500,20 +500,20 @@ class TestSetDistReproDataloader: | |||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | already_seen_idx.update(batch) | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
# After reloading, the remaining items should be produced; for PaddleNormalDataset, sorting them should give a contiguous range | # After reloading, the remaining items should be produced; for PaddleNormalDataset, sorting them should give a contiguous range | ||||
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# Rebuild the dataloader | # Rebuild the dataloader | ||||
new_loader = DataLoader( | new_loader = DataLoader( | ||||
dataset=replaced_loader.dataset, | dataset=replaced_loader.dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), | BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), | ||||
batch_size=batch_size, | batch_size=batch_size, | ||||
drop_last=False, | drop_last=False, | ||||
@@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
dataset = PaddleRandomMaxDataset(40, 10) | dataset = PaddleRandomMaxDataset(40, 10) | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) | |||||
) | ) | ||||
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") | ||||
@@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# Change the batch_size | # Change the batch_size | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | |||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | |||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | ||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
@@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# 2. Check that the batch_sampler was correctly loaded and replaced | # 2. Check that the batch_sampler was correctly loaded and replaced | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | assert replaced_loader.batch_sampler is dataloader.batch_sampler | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | ||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | ||||
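The bookkeeping behind the final assertion is worth spelling out: resumption is tracked in samples, so consuming a batch advances `num_consumed_samples` by the batch size. In plain Python, for the dataset and batch size used above:

# PaddleRandomMaxDataset(40, 10) with batch_size=4 and two batches consumed:
dataset_len, batch_size, num_consumed_batches = 40, 4, 2
num_consumed_samples = num_consumed_batches * batch_size  # 8, matching the assertion
remaining = dataset_len - num_consumed_samples            # 32 samples left to yield
assert num_consumed_samples == 8 and remaining == 32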
@@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import ( | |||||
replace_batch_sampler, | replace_batch_sampler, | ||||
replace_sampler, | replace_sampler, | ||||
) | ) | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, | |||||
def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
dataset = PaddleNormalDataset(10) | dataset = PaddleNormalDataset(10) | ||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, PaddleNormalDataset) | assert isinstance(replaced_loader.dataset, PaddleNormalDataset) | ||||
assert len(replaced_loader.dataset) == len(dataset) | assert len(replaced_loader.dataset) == len(dataset) | ||||
assert replaced_loader.batch_sampler.batch_size == 16 | assert replaced_loader.batch_sampler.batch_size == 16 | ||||
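As the rename suggests, `ReproduceBatchSampler` is a wrapper rather than a sampler in its own right: it delegates batching to an existing batch sampler and layers `state_dict()`/`load_state_dict()` on top so iteration can resume. A minimal torch-flavoured sketch (the paddle usage above is symmetric; state keys such as `num_consumed_samples` are taken from the tests in this diff):

from torch.utils.data import BatchSampler, SequentialSampler
from fastNLP.core.samplers import ReproduceBatchSampler

inner = BatchSampler(SequentialSampler(range(100)), batch_size=16, drop_last=False)
sampler = ReproduceBatchSampler(inner, batch_size=16, drop_last=False)

it = iter(sampler)
next(it)                        # consume one batch of indices
states = sampler.state_dict()   # snapshot of progress, e.g. num_consumed_samples

resumed = ReproduceBatchSampler(inner, batch_size=16, drop_last=False)
resumed.load_state_dict(states) # a fresh wrapper continues where we stopped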
@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
import torch | |||||
import torch.distributed as dist | |||||
from torch.utils.data import DataLoader, BatchSampler | |||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): | |||||
def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): | |||||
torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) | ||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | ||||
device = [torch.device(i) for i in device] | device = [torch.device(i) for i in device] | ||||
@@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | |||||
def test_multi_drivers(): | |||||
""" | |||||
Tests the case where multiple TorchDDPDriver instances are used.
""" | |||||
generate_driver(10, 10) | |||||
generate_driver(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
# different device settings should raise an error
generate_driver(20, 3, device=[0,1,2]) | |||||
assert False | |||||
dist.barrier() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestDDPDriverFunction: | class TestDDPDriverFunction: | ||||
""" | """ | ||||
Test class for some simple TorchDDPDriver functions; these mostly check that things run at all and that there are no import errors | Test class for some simple TorchDDPDriver functions; these mostly check that things run at all and that there are no import errors | ||||
""" | """ | ||||
@classmethod | |||||
def setup_class(cls): | |||||
cls.driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_multi_drivers(self): | |||||
def test_simple_functions(self): | |||||
""" | """ | ||||
Tests the case where multiple TorchDDPDriver instances are used. | |||||
Simple tests of several functions
""" | """ | ||||
driver2 = generate_driver(20, 10) | |||||
with pytest.raises(RuntimeError): | |||||
# different device settings should raise an error | |||||
driver3 = generate_driver(20, 3, device=[0,1,2]) | |||||
assert False | |||||
dist.barrier() | |||||
driver = generate_driver(10, 10) | |||||
@magic_argv_env_context | |||||
def test_move_data_to_device(self): | |||||
""" | """ | ||||
This function only calls torch_move_data_to_device; the test cases are in tests/core/utils/test_torch_utils.py | |||||
so they are not repeated here | |||||
Tests the move_data_to_device function. This function only calls torch_move_data_to_device; the test cases are in
tests/core/utils/test_torch_utils.py, so they are not repeated here
""" | """ | ||||
self.driver.move_data_to_device(torch.rand((32, 64))) | |||||
driver.move_data_to_device(torch.rand((32, 64))) | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_is_distributed(self): | |||||
""" | """ | ||||
Tests the is_distributed function | Tests the is_distributed function | ||||
""" | """ | ||||
assert self.driver.is_distributed() == True | |||||
assert driver.is_distributed() == True | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_get_no_sync_context(self): | |||||
""" | """ | ||||
Tests the get_no_sync_context function | Tests the get_no_sync_context function | ||||
""" | """ | ||||
res = self.driver.get_model_no_sync_context() | |||||
res = driver.get_model_no_sync_context() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_is_global_zero(self): | |||||
""" | """ | ||||
Tests the is_global_zero function | Tests the is_global_zero function | ||||
""" | """ | ||||
self.driver.is_global_zero() | |||||
driver.is_global_zero() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_unwrap_model(self): | |||||
""" | """ | ||||
Tests the unwrap_model function | Tests the unwrap_model function | ||||
""" | """ | ||||
self.driver.unwrap_model() | |||||
driver.unwrap_model() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_get_local_rank(self): | |||||
""" | """ | ||||
Tests the get_local_rank function | Tests the get_local_rank function | ||||
""" | """ | ||||
self.driver.get_local_rank() | |||||
driver.get_local_rank() | |||||
dist.barrier() | dist.barrier() | ||||
@magic_argv_env_context | |||||
def test_all_gather(self): | |||||
""" | """ | ||||
Tests the all_gather function | Tests the all_gather function | ||||
Detailed tests are carried out in test_dist_utils.py | Detailed tests are carried out in test_dist_utils.py | ||||
""" | """ | ||||
obj = { | obj = { | ||||
"rank": self.driver.global_rank | |||||
"rank": driver.global_rank | |||||
} | } | ||||
obj_list = self.driver.all_gather(obj, group=None) | |||||
obj_list = driver.all_gather(obj, group=None) | |||||
for i, res in enumerate(obj_list): | for i, res in enumerate(obj_list): | ||||
assert res["rank"] == i | assert res["rank"] == i | ||||
@magic_argv_env_context | |||||
@pytest.mark.parametrize("src_rank", ([0, 1])) | |||||
def test_broadcast_object(self, src_rank): | |||||
""" | """ | ||||
Tests the broadcast_object function | Tests the broadcast_object function | ||||
Detailed tests are carried out in test_dist_utils.py | Detailed tests are carried out in test_dist_utils.py | ||||
""" | """ | ||||
if self.driver.global_rank == src_rank: | |||||
if driver.global_rank == 0: | |||||
obj = { | obj = { | ||||
"rank": self.driver.global_rank | |||||
"rank": driver.global_rank | |||||
} | } | ||||
else: | else: | ||||
obj = None | obj = None | ||||
res = self.driver.broadcast_object(obj, src=src_rank) | |||||
assert res["rank"] == src_rank | |||||
res = driver.broadcast_object(obj, src=0) | |||||
assert res["rank"] == 0 | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||
@@ -187,7 +180,6 @@ class TestSetDistReproDataloader: | |||||
@classmethod | @classmethod | ||||
def setup_class(cls): | def setup_class(cls): | ||||
cls.device = [0, 1] | cls.device = [0, 1] | ||||
cls.driver = generate_driver(10, 10, device=cls.device) | |||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchNormalDataset(40) | self.dataset = TorchNormalDataset(40) | ||||
@@ -204,17 +196,20 @@ class TestSetDistReproDataloader: | |||||
Tests set_dist_repro_dataloader when dist is a BucketedBatchSampler | Tests set_dist_repro_dataloader when dist is a BucketedBatchSampler | ||||
The batch_sampler should be replaced with the BucketedBatchSampler passed as dist | The batch_sampler should be replaced with the BucketedBatchSampler passed as dist | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | ||||
batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
assert replaced_loader.batch_sampler is batch_sampler | assert replaced_loader.batch_sampler is batch_sampler | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -223,9 +218,10 @@ class TestSetDistReproDataloader: | |||||
Tests set_dist_repro_dataloader when dist is a RandomSampler | Tests set_dist_repro_dataloader when dist is a RandomSampler | ||||
The batch_sampler.sampler should be replaced with the RandomSampler passed as dist | The batch_sampler.sampler should be replaced with the RandomSampler passed as dist | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) | ||||
sampler = RandomSampler(self.dataset, shuffle=shuffle) | sampler = RandomSampler(self.dataset, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -234,9 +230,11 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler is sampler | assert replaced_loader.batch_sampler.sampler is sampler | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
""" | """ | ||||
Cases where the `dist` argument is None; this arises while initializing the trainer or evaluator when the user sets `use_dist_sampler` | Cases where the `dist` argument is None; this arises while initializing the trainer or evaluator when the user sets `use_dist_sampler` | ||||
@@ -251,15 +249,17 @@ class TestSetDistReproDataloader: | |||||
Tests set_dist_repro_dataloader when dist is None and reproducible is True | Tests set_dist_repro_dataloader when dist is None and reproducible is True | ||||
When the user initialized the distributed environment outside the driver, fastNLP does not support checkpoint resumption, so an error should be raised | When the user initialized the distributed environment outside the driver, fastNLP does not support checkpoint resumption, so an error should be raised | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) | ||||
with pytest.raises(RuntimeError): | with pytest.raises(RuntimeError): | ||||
# a RuntimeError should be raised | # a RuntimeError should be raised | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
# @pytest.mark.parametrize("shuffle", ([True, False])) | |||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): | ||||
""" | """ | ||||
@@ -268,21 +268,24 @@ class TestSetDistReproDataloader: | |||||
Here the incoming dataloader's batch_sampler has already had set_distributed called on it; a new dataloader is produced whose batch_sampler | Here the incoming dataloader's batch_sampler has already had set_distributed called on it; a new dataloader is produced whose batch_sampler | ||||
is the same as the original dataloader's | is the same as the original dataloader's | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == 4 | assert replaced_loader.batch_sampler.batch_size == 4 | ||||
self.check_distributed_sampler(dataloader.batch_sampler) | self.check_distributed_sampler(dataloader.batch_sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -292,12 +295,13 @@ class TestSetDistReproDataloader: | |||||
Here the incoming dataloader's batch_sampler.sampler has already had set_distributed called on it; a new dataloader is produced whose | Here the incoming dataloader's batch_sampler.sampler has already had set_distributed called on it; a new dataloader is produced whose | ||||
batch_sampler.sampler is the same as the original dataloader's | batch_sampler.sampler is the same as the original dataloader's | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank | |||||
) | ) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -307,9 +311,11 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.batch_size == 4 | assert replaced_loader.batch_sampler.batch_size == 4 | ||||
assert replaced_loader.batch_sampler.drop_last == False | assert replaced_loader.batch_sampler.drop_last == False | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | |||||
self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) | |||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -318,11 +324,14 @@ class TestSetDistReproDataloader: | |||||
Tests set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one | Tests set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one | ||||
The original dataloader is returned directly, without any processing. | The original dataloader is returned directly, without any processing. | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) | |||||
assert replaced_loader is dataloader | assert replaced_loader is dataloader | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
""" | """ | ||||
Cases where the `dist` argument is 'dist'; this arises while initializing the trainer when the user sets the `use_dist_sampler` argument | Cases where the `dist` argument is 'dist'; this arises while initializing the trainer when the user sets the `use_dist_sampler` argument | ||||
@@ -337,12 +346,13 @@ class TestSetDistReproDataloader: | |||||
behaviour | behaviour | ||||
A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, and the distributed-related attributes should be set correctly | A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, and the distributed-related attributes should be set correctly | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=self.dataset, | dataset=self.dataset, | ||||
batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) | ||||
) | ) | ||||
dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) | ||||
@@ -351,6 +361,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -361,8 +373,9 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, and the distributed-related | A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, and the distributed-related | ||||
attributes should be set correctly | attributes should be set correctly | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
@@ -372,6 +385,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -381,8 +396,9 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, and the distributed-related | A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, and the distributed-related | ||||
attributes should be set correctly | attributes should be set correctly | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -392,6 +408,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
""" | """ | ||||
Cases where the `dist` argument is 'unrepeatdist'; this arises while initializing the evaluator when the user sets the `use_dist_sampler` argument | Cases where the `dist` argument is 'unrepeatdist'; this arises while initializing the evaluator when the user sets the `use_dist_sampler` argument | ||||
@@ -407,8 +425,9 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned with the original Sampler replaced by an UnrepeatedRandomSampler, and the distributed-related | A new dataloader should be returned with the original Sampler replaced by an UnrepeatedRandomSampler, and the distributed-related | ||||
attributes should be set correctly | attributes should be set correctly | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -418,6 +437,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | assert replaced_loader.batch_sampler.sampler.shuffle == shuffle | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -427,8 +448,9 @@ class TestSetDistReproDataloader: | |||||
behaviour | behaviour | ||||
A new dataloader should be returned, with the original Sampler re-instantiated | A new dataloader should be returned, with the original Sampler re-instantiated | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -439,6 +461,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("shuffle", ([True, False])) | @pytest.mark.parametrize("shuffle", ([True, False])) | ||||
@@ -448,8 +472,9 @@ class TestSetDistReproDataloader: | |||||
A new dataloader should be returned with the sampler replaced by an UnrepeatedSequentialSampler, and the distributed-related | A new dataloader should be returned with the sampler replaced by an UnrepeatedSequentialSampler, and the distributed-related | ||||
attributes should be set correctly | attributes should be set correctly | ||||
""" | """ | ||||
driver = generate_driver(10, 10, device=self.device) | |||||
dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) | |||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler, BatchSampler) | ||||
@@ -459,6 +484,8 @@ class TestSetDistReproDataloader: | |||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) | ||||
dist.barrier() | dist.barrier() | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
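Taken together, the test families above pin down the contract of the `dist` argument. Schematically (a sketch reusing the `driver` and `self.dataset` fixtures from this class, not standalone code):

dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)

# trainer path: every replica gets a repeatable, padded shard
train_loader = driver.set_dist_repro_dataloader(dataloader, 'dist', False)

# evaluator path: samplers become Unrepeated* variants, so no sample
# is evaluated on more than one replica
eval_loader = driver.set_dist_repro_dataloader(dataloader, 'unrepeatdist', False)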
def check_distributed_sampler(self, sampler): | def check_distributed_sampler(self, sampler): | ||||
""" | """ | ||||
@@ -469,7 +496,7 @@ class TestSetDistReproDataloader: | |||||
if not isinstance(sampler, UnrepeatedSampler): | if not isinstance(sampler, UnrepeatedSampler): | ||||
assert sampler.pad == True | assert sampler.pad == True | ||||
def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): | |||||
def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): | |||||
""" | """ | ||||
Tests that set_dist_repro_dataloader produces correct results in the multi-GPU case | Tests that set_dist_repro_dataloader produces correct results in the multi-GPU case | ||||
""" | """ | ||||
@@ -501,8 +528,8 @@ class TestSetDistReproDataloader: | |||||
drop_last=False, | drop_last=False, | ||||
) | ) | ||||
new_loader.batch_sampler.set_distributed( | new_loader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank, | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
new_loader.batch_sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.load_state_dict(sampler_states) | ||||
@@ -512,8 +539,8 @@ class TestSetDistReproDataloader: | |||||
# Rebuild the dataloader | # Rebuild the dataloader | ||||
new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) | ||||
new_loader.batch_sampler.sampler.set_distributed( | new_loader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver.world_size, | |||||
rank=self.driver.global_rank | |||||
num_replicas=driver.world_size, | |||||
rank=driver.global_rank | |||||
) | ) | ||||
new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | new_loader.batch_sampler.sampler.load_state_dict(sampler_states) | ||||
for idx, batch in enumerate(new_loader): | for idx, batch in enumerate(new_loader): | ||||
@@ -534,11 +561,6 @@ class TestSaveLoad: | |||||
Tests the behaviour of the save- and load-related functions in the multi-GPU case | Tests the behaviour of the save- and load-related functions in the multi-GPU case | ||||
""" | """ | ||||
@classmethod | |||||
def setup_class(cls): | |||||
# it raises an error unless the setup happens here |
cls.driver = generate_driver(10, 10) | |||||
def setup_method(self): | def setup_method(self): | ||||
self.dataset = TorchArgMaxDataset(10, 20) | self.dataset = TorchArgMaxDataset(10, 20) | ||||
@@ -552,26 +574,26 @@ class TestSaveLoad: | |||||
path = "model" | path = "model" | ||||
dataloader = DataLoader(self.dataset, batch_size=2) | dataloader = DataLoader(self.dataset, batch_size=2) | ||||
self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) | |||||
self.driver1.save_model(path, only_state_dict) | |||||
driver1.save_model(path, only_state_dict) | |||||
# synchronize | # synchronize
dist.barrier() | dist.barrier() | ||||
self.driver2.load_model(path, only_state_dict) | |||||
driver2.load_model(path, only_state_dict) | |||||
for idx, batch in enumerate(dataloader): | for idx, batch in enumerate(dataloader): | ||||
batch = self.driver1.move_data_to_device(batch) | |||||
res1 = self.driver1.model( | |||||
batch = driver1.move_data_to_device(batch) | |||||
res1 = driver1.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | # Driver.model -> DataParallel.module -> _FleetWrappingModel.model | ||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
res2 = self.driver2.model( | |||||
res2 = driver2.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
@@ -580,6 +602,9 @@ class TestSaveLoad: | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
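The dist.barrier() above is what makes the cross-driver hand-off safe: the checkpoint must be fully written before any rank opens it. A minimal sketch of the same handshake in plain torch.distributed terms (hypothetical model objects, not the fastNLP driver API):

import torch
import torch.distributed as dist

def checkpoint_handoff(src_model, dst_model, path):
    # the writing rank finishes first; the barrier guarantees the file
    # exists on shared storage before any rank tries to read it
    if dist.get_rank() == 0:
        torch.save(src_model.state_dict(), path)
    dist.barrier()
    dst_model.load_state_dict(torch.load(path, map_location="cpu"))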
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
@@ -593,7 +618,7 @@ class TestSaveLoad: | |||||
path = "model.ckp" | path = "model.ckp" | ||||
num_replicas = len(device) | num_replicas = len(device) | ||||
self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ | |||||
generate_driver(10, 10, device=device, fp16=False) | generate_driver(10, 10, device=device, fp16=False) | ||||
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
self.dataset, | self.dataset, | ||||
@@ -603,8 +628,8 @@ class TestSaveLoad: | |||||
drop_last=False | drop_last=False | ||||
) | ) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
num_replicas=driver1.world_size, | |||||
rank=driver1.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
@@ -623,7 +648,7 @@ class TestSaveLoad: | |||||
# save the state | # save the state
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# load | # load
# change the batch_size | # change the batch_size
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
@@ -634,11 +659,11 @@ class TestSaveLoad: | |||||
drop_last=False | drop_last=False | ||||
) | ) | ||||
dataloader.batch_sampler.set_distributed( | dataloader.batch_sampler.set_distributed( | ||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
num_replicas=driver2.world_size, | |||||
rank=driver2.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. check the optimizer state | # 1. check the optimizer state
# TODO the optimizer's state_dict is always empty | # TODO the optimizer's state_dict is always empty
@@ -652,7 +677,7 @@ class TestSaveLoad: | |||||
# 3. check whether fp16 was loaded | # 3. check whether fp16 was loaded
if fp16: | if fp16: | ||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. check that the model parameters are correct | # 4. check that the model parameters are correct
# 5. check batch_idx | # 5. check batch_idx
@@ -664,16 +689,16 @@ class TestSaveLoad: | |||||
left_x_batches.update(batch["x"]) | left_x_batches.update(batch["x"]) | ||||
left_y_batches.update(batch["y"]) | left_y_batches.update(batch["y"]) | ||||
res1 = self.driver1.model( | |||||
res1 = driver1.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | # Driver.model -> DataParallel.module -> _FleetWrappingModel.model | ||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
res2 = self.driver2.model( | |||||
res2 = driver2.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
@@ -686,6 +711,9 @@ class TestSaveLoad: | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
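In outline, the accounting behind checks 4 and 5: the checkpoint records consumed samples, not batches, so a resumed run may change the batch size and still owns exactly the untouched remainder of its shard. A back-of-envelope sketch with illustrative values:

len_dataset, num_replicas = 40, 2                    # illustrative values
old_batch_size, num_consumed_batches = 4, 2
shard_size = len_dataset // num_replicas             # 20 samples per rank (pad=True, divisible)
consumed = num_consumed_batches * old_batch_size     # 8 samples already trained on
remaining = shard_size - consumed                    # what the replaced loader must still yield
assert consumed + remaining == shard_size            # seen-before-save | seen-after-load == one shard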
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
@@ -700,13 +728,13 @@ class TestSaveLoad: | |||||
num_replicas = len(device) | num_replicas = len(device) | ||||
self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
self.driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
driver1 = generate_driver(10, 10, device=device, fp16=fp16) | |||||
driver2 = generate_driver(10, 10, device=device, fp16=False) | |||||
dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver1.world_size, | |||||
rank=self.driver1.global_rank, | |||||
num_replicas=driver1.world_size, | |||||
rank=driver1.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
num_consumed_batches = 2 | num_consumed_batches = 2 | ||||
@@ -726,18 +754,18 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
# load | # load
# change the batch_size | # change the batch_size
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | ||||
dataloader.batch_sampler.sampler.set_distributed( | dataloader.batch_sampler.sampler.set_distributed( | ||||
num_replicas=self.driver2.world_size, | |||||
rank=self.driver2.global_rank, | |||||
num_replicas=driver2.world_size, | |||||
rank=driver2.global_rank, | |||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. check the optimizer state | # 1. check the optimizer state
@@ -753,7 +781,7 @@ class TestSaveLoad: | |||||
assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] | ||||
# 3. check whether fp16 was loaded | # 3. check whether fp16 was loaded
if fp16: | if fp16: | ||||
assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) | |||||
# 4. check that the model parameters are correct | # 4. check that the model parameters are correct
# 5. check batch_idx | # 5. check batch_idx
@@ -765,16 +793,16 @@ class TestSaveLoad: | |||||
left_x_batches.update(batch["x"]) | left_x_batches.update(batch["x"]) | ||||
left_y_batches.update(batch["y"]) | left_y_batches.update(batch["y"]) | ||||
res1 = self.driver1.model( | |||||
res1 = driver1.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver1.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver1.model.module.model.evaluate_step, | |||||
# Driver.model -> DataParallel.module -> _FleetWrappingModel.model | # Driver.model -> DataParallel.module -> _FleetWrappingModel.model | ||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
res2 = self.driver2.model( | |||||
res2 = driver2.model( | |||||
batch, | batch, | ||||
fastnlp_fn=self.driver2.model.module.model.evaluate_step, | |||||
fastnlp_fn=driver2.model.module.model.evaluate_step, | |||||
fastnlp_signature_fn=None, | fastnlp_signature_fn=None, | ||||
wo_auto_param_call=False, | wo_auto_param_call=False, | ||||
) | ) | ||||
@@ -786,4 +814,7 @@ class TestSaveLoad: | |||||
assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas | ||||
finally: | finally: | ||||
rank_zero_rm(path) | |||||
rank_zero_rm(path) | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() |
@@ -2,12 +2,14 @@ import pytest | |||||
from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver | ||||
from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver | ||||
from fastNLP.envs import get_gpu_count | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
import torch | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import device as torchdevice | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as torchdevice | |||||
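The motivation for this change: torch.device("cuda:0") used to appear inside a @pytest.mark.parametrize argument, and decorator arguments are evaluated at import time, so merely collecting the module crashed when torch was absent. A self-contained sketch of the fallback mechanics (assuming DummyClass simply swallows instantiation; the @pytest.mark.torch marker keeps the tests skipped, so the dummy is never exercised):

try:
    from torch import device as torchdevice
except ImportError:
    class torchdevice:                      # stand-in so parametrize lists still build
        def __init__(self, *args, **kwargs):
            pass

params = ["cpu", "cuda:0", 0, torchdevice("cuda:0")]  # safe to evaluate without torch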
@pytest.mark.torch | @pytest.mark.torch | ||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
@@ -20,7 +22,7 @@ def test_incorrect_driver(): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "cuda:0", 0, torch.device("cuda:0")] | |||||
["cpu", "cuda:0", 0, torchdevice("cuda:0")] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
@@ -83,7 +85,6 @@ def test_get_ddp(driver, device): | |||||
("driver", "device"), | ("driver", "device"), | ||||
[("torch_ddp", "cpu")] | [("torch_ddp", "cpu")] | ||||
) | ) | ||||
@magic_argv_env_context | |||||
def test_get_ddp_cpu(driver, device): | def test_get_ddp_cpu(driver, device): | ||||
""" | """ | ||||
Test attempting to initialise distributed training on cpu | Test attempting to initialise distributed training on cpu
@@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] | |||||
[-2, [0, 20, 3], [-2], 20] | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
["torch", "torch_ddp"] | ["torch", "torch_ddp"] | ||||
) | ) | ||||
@magic_argv_env_context | |||||
def test_device_out_of_range(driver, device): | def test_device_out_of_range(driver, device): | ||||
""" | """ | ||||
Test the case where the given device is out of range | Test the case where the given device is out of range
@@ -2,7 +2,7 @@ import pytest | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset | ||||
@@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE: | |||||
def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | ||||
""" | """ | ||||
Build a dataloader whose batch_sampler is a RandomBatchSampler |
Build a dataloader whose batch_sampler is a ReproduceBatchSampler
""" | """ | ||||
if shuffle: | if shuffle: | ||||
sampler = torch.utils.data.RandomSampler(dataset) | sampler = torch.utils.data.RandomSampler(dataset) | ||||
@@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): | |||||
sampler = torch.utils.data.SequentialSampler(dataset) | sampler = torch.utils.data.SequentialSampler(dataset) | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=RandomBatchSampler( | |||||
batch_sampler=ReproduceBatchSampler( | |||||
BatchSampler( | BatchSampler( | ||||
sampler, batch_size=batch_size, drop_last=drop_last | sampler, batch_size=batch_size, drop_last=drop_last | ||||
), | ), | ||||
@@ -306,7 +306,7 @@ class TestTorchDriverFunctions: | |||||
res = TorchSingleDriver.get_dataloader_args(dataloader) | res = TorchSingleDriver.get_dataloader_args(dataloader) | ||||
assert isinstance(res.dataset, TorchNormalDataset) | assert isinstance(res.dataset, TorchNormalDataset) | ||||
assert isinstance(res.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(res.batch_sampler, ReproduceBatchSampler) | |||||
if shuffle: | if shuffle: | ||||
assert isinstance(res.sampler, torch.utils.data.RandomSampler) | assert isinstance(res.sampler, torch.utils.data.RandomSampler) | ||||
else: | else: | ||||
@@ -401,7 +401,7 @@ class TestSetDistReproDataloader: | |||||
""" | """ | ||||
Test the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True | Test the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True
When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler(shuffle=True), | When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler(shuffle=True),
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler |
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler (see the sketch below)
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) | ||||
@@ -414,7 +414,7 @@ class TestSetDistReproDataloader: | |||||
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) | ||||
else: | else: | ||||
# in this case the batch_sampler is replaced | # in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
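Putting the docstring's two branches side by side (a sketch reusing the fixtures of this test class, self.driver and self.dataset, exactly as the assertions above use them):

# shuffle=True: only the torch RandomSampler is swapped for fastNLP's
# reproducible RandomSampler; the surrounding BatchSampler stays
shuffled = DataLoader(self.dataset, batch_size=2, shuffle=True)
out = self.driver.set_dist_repro_dataloader(shuffled, dist="dist", reproducible=True)
assert isinstance(out.batch_sampler.sampler, RandomSampler)

# shuffle=False: the whole batch_sampler is wrapped in a ReproduceBatchSampler
sequential = DataLoader(self.dataset, batch_size=2, shuffle=False)
out = self.driver.set_dist_repro_dataloader(sequential, dist="dist", reproducible=True)
assert isinstance(out.batch_sampler, ReproduceBatchSampler)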
@@ -428,11 +428,11 @@ class TestSetDistReproDataloader: | |||||
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler | ||||
""" | """ | ||||
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) | ||||
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler is dist | assert replaced_loader.batch_sampler is dist | ||||
self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) | ||||
@@ -466,7 +466,7 @@ class TestSetDistReproDataloader: | |||||
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) | ||||
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size | ||||
assert replaced_loader.drop_last == dataloader.drop_last | assert replaced_loader.drop_last == dataloader.drop_last | ||||
@@ -502,14 +502,14 @@ class TestSetDistReproDataloader: | |||||
if idx >= num_consumed_batches: | if idx >= num_consumed_batches: | ||||
break | break | ||||
already_seen_idx.update(batch) | already_seen_idx.update(batch) | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
sampler_states = replaced_loader.batch_sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.state_dict() | ||||
else: | else: | ||||
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | sampler_states = replaced_loader.batch_sampler.sampler.state_dict() | ||||
# after reloading, the remaining content should come out; for TorchNormalDataset the sorted result should be a range | # after reloading, the remaining content should come out; for TorchNormalDataset the sorted result should be a range
left_idxes = set() | left_idxes = set() | ||||
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): | |||||
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): | |||||
batch_size = replaced_loader.batch_sampler.batch_size | batch_size = replaced_loader.batch_sampler.batch_size | ||||
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size | ||||
# rebuild the dataloader | # rebuild the dataloader
@@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# 2. check that the batch_sampler was correctly loaded and replaced | # 2. check that the batch_sampler was correctly loaded and replaced
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert replaced_loader.batch_sampler is dataloader.batch_sampler | assert replaced_loader.batch_sampler is dataloader.batch_sampler | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] | ||||
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 | ||||
@@ -30,7 +30,7 @@ class SequenceDataSet: | |||||
def check_replace_sampler(driver): | def check_replace_sampler(driver): | ||||
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / RandomBatchSampler |
# dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler / ReproduceBatchSampler
# reproducible is either True or False | # reproducible is either True or False
# need to check that both the returned sampler and the dataloader have changed | # need to check that both the returned sampler and the dataloader have changed
@@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||||
replace_batch_sampler, | replace_batch_sampler, | ||||
replace_sampler, | replace_sampler, | ||||
) | ) | ||||
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler | |||||
from torch.utils.data import DataLoader, BatchSampler | from torch.utils.data import DataLoader, BatchSampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
dataset = TorchNormalDataset(10) | dataset = TorchNormalDataset(10) | ||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) | |||||
replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | replaced_loader = replace_batch_sampler(dataloader, batch_sampler) | ||||
assert not (replaced_loader is dataloader) | assert not (replaced_loader is dataloader) | ||||
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) | |||||
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) | |||||
assert isinstance(replaced_loader.dataset, TorchNormalDataset) | assert isinstance(replaced_loader.dataset, TorchNormalDataset) | ||||
assert len(replaced_loader.dataset) == len(dataset) | assert len(replaced_loader.dataset) == len(dataset) | ||||
assert replaced_loader.batch_sampler.batch_size == 16 | assert replaced_loader.batch_sampler.batch_size == 16 | ||||
@@ -7,15 +7,20 @@ import copy | |||||
import socket | import socket | ||||
import pytest | import pytest | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
import torch.distributed | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
from sklearn.metrics import accuracy_score as sklearn_accuracy | from sklearn.metrics import accuracy_score as sklearn_accuracy | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.metrics.accuracy import Accuracy | from fastNLP.core.metrics.accuracy import Accuracy | ||||
from fastNLP.core.metrics.metric import Metric | from fastNLP.core.metrics.metric import Metric | ||||
from .utils import find_free_network_port, setup_ddp, _assert_allclose | from .utils import find_free_network_port, setup_ddp, _assert_allclose | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method | |||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
@@ -26,7 +31,7 @@ pool = None | |||||
def _test(local_rank: int, | def _test(local_rank: int, | ||||
world_size: int, | world_size: int, | ||||
device: torch.device, | |||||
device: "torch.device", | |||||
dataset: DataSet, | dataset: DataSet, | ||||
metric_class: Type[Metric], | metric_class: Type[Metric], | ||||
metric_kwargs: Dict[str, Any], | metric_kwargs: Dict[str, Any], | ||||
@@ -2,18 +2,23 @@ from functools import partial | |||||
import copy | import copy | ||||
import pytest | import pytest | ||||
import torch | |||||
import numpy as np | import numpy as np | ||||
from torch.multiprocessing import Pool, set_start_method | |||||
from fastNLP.core.metrics import ClassifyFPreRecMetric | from fastNLP.core.metrics import ClassifyFPreRecMetric | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from .utils import find_free_network_port, setup_ddp | from .utils import find_free_network_port, setup_ddp | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method | |||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
def _test(local_rank: int, world_size: int, device: torch.device, | |||||
def _test(local_rank: int, world_size: int, device: "torch.device", | |||||
dataset: DataSet, metric_class, metric_kwargs, metric_result): | dataset: DataSet, metric_class, metric_kwargs, metric_result): | ||||
metric = metric_class(**metric_kwargs) | metric = metric_class(**metric_kwargs) | ||||
# the dataset works similarly (each process holds its own copy) | # the dataset works similarly (each process holds its own copy)
@@ -5,16 +5,21 @@ import os, sys | |||||
import copy | import copy | ||||
from functools import partial | from functools import partial | ||||
import torch | |||||
import torch.distributed | |||||
import numpy as np | import numpy as np | ||||
import socket | import socket | ||||
from torch.multiprocessing import Pool, set_start_method | |||||
# from multiprocessing import Pool, set_start_method | # from multiprocessing import Pool, set_start_method | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from .utils import find_free_network_port, setup_ddp | from .utils import find_free_network_port, setup_ddp | ||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
import torch.distributed | |||||
from torch.multiprocessing import Pool, set_start_method | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as set_start_method | |||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
@@ -44,7 +49,7 @@ pool = None | |||||
def _test(local_rank: int, | def _test(local_rank: int, | ||||
world_size: int, | world_size: int, | ||||
device: torch.device, | |||||
device: "torch.device", | |||||
dataset: DataSet, | dataset: DataSet, | ||||
metric_class, | metric_class, | ||||
metric_kwargs, | metric_kwargs, | ||||
@@ -2,9 +2,11 @@ import os, sys | |||||
import socket | import socket | ||||
from typing import Union | from typing import Union | ||||
import torch | |||||
from torch import distributed | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch import distributed | |||||
def setup_ddp(rank: int, world_size: int, master_port: int) -> None: | def setup_ddp(rank: int, world_size: int, master_port: int) -> None: | ||||
@@ -1,161 +1,131 @@ | |||||
from array import array | |||||
import numpy as np | import numpy as np | ||||
import pytest | import pytest | ||||
from itertools import chain | from itertools import chain | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from array import array | |||||
from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler | |||||
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler | |||||
class TestReproducibleBatchSampler: | |||||
def test_1(self): | |||||
sampler = NormalSampler(num_of_data=100) # whether or not this is a batch sampler makes no difference;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False) | |||||
forward_steps = 3 | |||||
iterator = iter(reproduce_batch_sampler) | |||||
i = 0 | |||||
while i < forward_steps: | |||||
next(iterator) | |||||
i += 1 | |||||
# save the state;
state = reproduce_batch_sampler.state_dict() | |||||
assert state == {"index_list": array("I", list(range(100))), | |||||
"num_consumed_samples": forward_steps * 4, | |||||
"sampler_type": "ReproduceBatchSampler"} | |||||
# create a new batch sampler, then load the state;
sampler = NormalSampler(num_of_data=100) # whether or not this is a batch sampler makes no difference;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False) | |||||
reproduce_batch_sampler.load_state_dict(state) | |||||
real_res = [] | |||||
supposed_res = (list(range(12, 16)), list(range(16, 20))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(reproduce_batch_sampler) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert supposed_res[i] == real_res[i] | |||||
# change the batch size;
sampler = NormalSampler(num_of_data=100) # whether or not this is a batch sampler makes no difference;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False) | |||||
reproduce_batch_sampler.load_state_dict(state) | |||||
real_res = [] | |||||
supposed_res = (list(range(12, 19)), list(range(19, 26))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(reproduce_batch_sampler) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert supposed_res[i] == real_res[i] | |||||
# check that the second epoch after checkpoint-resume is a complete pass;
# first finish the epoch in which the resume happened;
begin_idx = 26 | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert data == list(range(begin_idx, begin_idx + _batch_size)) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
# start a new epoch;
begin_idx = 0 | |||||
iter_dataloader = iter(reproduce_batch_sampler) | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert data == list(range(begin_idx, begin_idx + _batch_size)) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
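The offsets asserted in test_1, spelled out: the saved state stores a sample count rather than a step count, so reloading with a different batch size still resumes at sample 12.

num_consumed_samples = 3 * 4       # three forward steps at batch_size=4
# resume with batch_size=4 -> batches [12..15], [16..19]
# resume with batch_size=7 -> batches [12..18], [19..25]
assert list(range(num_consumed_samples, num_consumed_samples + 7)) == list(range(12, 19))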
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
# | |||||
# class TestReproducibleBatchSampler: | |||||
# # TODO split this up; test only one thing here |
# def test_torch_dataloader_1(self): | |||||
# import torch | |||||
# from torch.utils.data import DataLoader | |||||
# # no shuffle | |||||
# before_batch_size = 7 | |||||
# dataset = TorchNormalDataset(num_of_data=100) | |||||
# dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# | |||||
# forward_steps = 3 | |||||
# iter_dataloader = iter(dataloader) | |||||
# for _ in range(forward_steps): | |||||
# next(iter_dataloader) | |||||
# | |||||
# # 1. save the state |
# _get_re_batchsampler = dataloader.batch_sampler | |||||
# assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
# state = _get_re_batchsampler.state_dict() | |||||
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||||
# "sampler_type": "RandomBatchSampler"} | |||||
# | |||||
# # 2. checkpoint-resume: build a new dataloader; |
# # keep batch_size unchanged; |
# dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
# re_batchsampler.load_state_dict(state) | |||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# | |||||
# real_res = [] | |||||
# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
# forward_steps = 2 | |||||
# iter_dataloader = iter(dataloader) | |||||
# for _ in range(forward_steps): | |||||
# real_res.append(next(iter_dataloader)) | |||||
# | |||||
# for i in range(forward_steps): | |||||
# assert all(real_res[i] == supposed_res[i]) | |||||
# | |||||
# # change batch_size; |
# after_batch_size = 3 | |||||
# dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
# re_batchsampler.load_state_dict(state) | |||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# | |||||
# real_res = [] | |||||
# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
# forward_steps = 2 | |||||
# iter_dataloader = iter(dataloader) | |||||
# for _ in range(forward_steps): | |||||
# real_res.append(next(iter_dataloader)) | |||||
# | |||||
# for i in range(forward_steps): | |||||
# assert all(real_res[i] == supposed_res[i]) | |||||
# | |||||
# # check that the second epoch after checkpoint-resume is a complete pass; |
# # first finish the epoch in which the resume happened; |
# begin_idx = 27 | |||||
# while True: | |||||
# try: | |||||
# data = next(iter_dataloader) | |||||
# _batch_size = len(data) | |||||
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
# begin_idx += _batch_size | |||||
# except StopIteration: | |||||
# break | |||||
# | |||||
# # start a new epoch; |
# begin_idx = 0 | |||||
# iter_dataloader = iter(dataloader) | |||||
# while True: | |||||
# try: | |||||
# data = next(iter_dataloader) | |||||
# _batch_size = len(data) | |||||
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
# begin_idx += _batch_size | |||||
# except StopIteration: | |||||
# break | |||||
# | |||||
# def test_torch_dataloader_2(self): | |||||
# # check that the new epoch's index list is regenerated rather than reused from the previous one; |
# from torch.utils.data import DataLoader | |||||
# # no shuffle | |||||
# before_batch_size = 7 | |||||
# dataset = TorchNormalDataset(num_of_data=100) | |||||
# # enable shuffle to verify that the second epoch after resume regenerates its index list; |
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# | |||||
# # record a full epoch of data to verify the restore is correct; |
# all_supposed_data = [] | |||||
# forward_steps = 3 | |||||
# iter_dataloader = iter(dataloader) | |||||
# for _ in range(forward_steps): | |||||
# all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
# | |||||
# # 1. save the state |
# _get_re_batchsampler = dataloader.batch_sampler | |||||
# assert isinstance(_get_re_batchsampler, RandomBatchSampler) | |||||
# state = _get_re_batchsampler.state_dict() | |||||
# | |||||
# # 2. checkpoint-resume: build a new dataloader; |
# # keep batch_size unchanged; |
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
# re_batchsampler.load_state_dict(state) | |||||
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# | |||||
# # first run through the rest of this epoch; |
# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
# while True: | |||||
# try: | |||||
# all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
# except StopIteration: | |||||
# break | |||||
# assert all_supposed_data == list(pre_index_list) | |||||
# | |||||
# # start a fresh epoch; |
# for _ in range(3): | |||||
# iter_dataloader = iter(dataloader) | |||||
# res = [] | |||||
# while True: | |||||
# try: | |||||
# res.append(next(iter_dataloader)) | |||||
# except StopIteration: | |||||
# break | |||||
# | |||||
# def test_3(self): | |||||
# import torch | |||||
# from torch.utils.data import DataLoader | |||||
# before_batch_size = 7 | |||||
# dataset = TorchNormalDataset(num_of_data=100) | |||||
# # enable shuffle to verify that the second epoch after resume regenerates its index list; |
# dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
# | |||||
# for idx, data in enumerate(dataloader): | |||||
# if idx > 3: | |||||
# break | |||||
# | |||||
# iterator = iter(dataloader) | |||||
# for each in iterator: | |||||
# pass | |||||
def test_2(self): | |||||
# check that the new epoch's index list is regenerated rather than reused from the previous one;
before_batch_size = 7 | |||||
sampler = NormalSampler(num_of_data=100) | |||||
# enable shuffle to verify that the second epoch after resume regenerates its index list;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False) | |||||
# record a full epoch of data to verify the restore is correct;
all_supposed_data = [] | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(reproduce_batch_sampler) | |||||
for _ in range(forward_steps): | |||||
all_supposed_data.extend(next(iter_dataloader)) | |||||
# 1. save the state
state = reproduce_batch_sampler.state_dict() | |||||
# 2. checkpoint-resume: re-create a dataloader;
# keep batch_size unchanged;
sampler = NormalSampler(num_of_data=100, shuffle=True) | |||||
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False) | |||||
reproduce_batch_sampler.load_state_dict(state) | |||||
# first run through the rest of this epoch;
pre_index_list = reproduce_batch_sampler.state_dict()["index_list"] | |||||
iter_dataloader = iter(reproduce_batch_sampler) | |||||
while True: | |||||
try: | |||||
all_supposed_data.extend(next(iter_dataloader)) | |||||
except StopIteration: | |||||
break | |||||
assert all_supposed_data == list(pre_index_list) | |||||
# start a fresh epoch;
for _ in range(3): | |||||
iter_dataloader = iter(reproduce_batch_sampler) | |||||
res = [] | |||||
while True: | |||||
try: | |||||
res.extend(next(iter_dataloader)) | |||||
except StopIteration: | |||||
break | |||||
assert res != all_supposed_data | |||||
class DatasetWithVaryLength: | class DatasetWithVaryLength: | ||||
@@ -511,3 +481,313 @@ class TestBucketedBatchSampler: | |||||
already_seen_set.update(batch) | already_seen_set.update(batch) | ||||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) | assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) | ||||
class TestRandomBatchSampler: | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71]) | |||||
def test_single_num_batch(self, shuffle, drop_last, num): | |||||
# insufficient data must not raise; `num` is already supplied by the parametrize above
dataset = DatasetWithVaryLength(num_of_data=num) | |||||
before_batch_size = 7 | |||||
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
count = len(list(iter(re_batchsampler))) | |||||
if drop_last: | |||||
assert count==num//before_batch_size, num | |||||
else: | |||||
assert count==(num+before_batch_size-1)//before_batch_size, num | |||||
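The expected counts above are plain floor/ceiling division, e.g.:

num, batch_size = 15, 7
assert num // batch_size == 2                     # drop_last=True keeps only full batches
assert (num + batch_size - 1) // batch_size == 3  # drop_last=False rounds up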
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
def test_single(self, shuffle, drop_last): | |||||
before_batch_size = 7 | |||||
num_batch_per_bucket = 4 # so the length difference within any batch should not exceed 4
dataset = DatasetWithVaryLength(num_of_data=1000) | |||||
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler.set_epoch(0) | |||||
forward_steps = 10 | |||||
iterator = iter(re_batchsampler) | |||||
already_generate_indices = set() | |||||
for _ in range(forward_steps): | |||||
batch = next(iterator) | |||||
already_generate_indices.update(batch) | |||||
# 1. save the state
state = re_batchsampler.state_dict()
# 2. checkpoint-resume: continue training
re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler2.load_state_dict(state) | |||||
re_batchsampler2.set_epoch(0) | |||||
new_already_generate_indices = set() | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices)-before_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i]) | |||||
for batch in re_batchsampler2: | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
new_already_generate_indices.update(batch) | |||||
if drop_last is False: | |||||
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset) | |||||
# change the batch_size;
after_batch_size = 3 | |||||
re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, | |||||
drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler3.load_state_dict(state) | |||||
re_batchsampler3.set_epoch(0) | |||||
count = 0 | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
for batch in re_batchsampler3: | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
count += 1 | |||||
if count > 5: | |||||
break | |||||
# saving again is not allowed while the resumed epoch is unfinished
after_batch_size = 5 | |||||
with pytest.raises(RuntimeError): | |||||
state = re_batchsampler3.state_dict() | |||||
for batch in re_batchsampler3: # consume everything, only then can we save
pass | |||||
already_generate_indices = set() | |||||
count = 0 | |||||
for batch in re_batchsampler3: # start over
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
count += 1 | |||||
if count > 5: | |||||
break | |||||
state = re_batchsampler3.state_dict() | |||||
# drop_last is False here, so in the end every sample must appear
re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, | |||||
drop_last=False, | |||||
shuffle=shuffle) | |||||
re_batchsampler4.load_state_dict(state) | |||||
re_batchsampler4.set_epoch(0) | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
for batch in re_batchsampler4: | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
assert len(already_generate_indices) == len(dataset) | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
def test_multi(self, shuffle, drop_last, pad): | |||||
# def test_multi(self, shuffle=True, drop_last=False, pad=False): | |||||
# no shuffle | |||||
num_replica = 2 | |||||
dataset = DatasetWithVaryLength(num_of_data=1000) | |||||
batch_size = 5 | |||||
num_batch_per_bucket = 10 | |||||
lengths = [] | |||||
rank0_already_seen_indexes = None | |||||
max_diff = num_batch_per_bucket * batch_size * num_replica | |||||
for rank in range(num_replica): | |||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size, | |||||
shuffle = shuffle, drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replica, rank=rank, pad=pad) | |||||
lengths.append(len(sampler)) | |||||
already_seen_indexes = set() | |||||
repeat_count = 0 | |||||
for batch in sampler: | |||||
for b in batch: | |||||
repeat_count += int(b in already_seen_indexes) | |||||
if rank0_already_seen_indexes: # must not overlap across ranks
assert b not in rank0_already_seen_indexes | |||||
already_seen_indexes.update(batch) | |||||
if rank0_already_seen_indexes is None: | |||||
rank0_already_seen_indexes = already_seen_indexes | |||||
if pad: # a single repeat is allowed
assert repeat_count<=1 | |||||
else: | |||||
assert repeat_count==0 | |||||
assert len(set(lengths))==1, lengths # every process sees the same number of batches
# saving in the multi-process case
already_seen_indexes = set() | |||||
for rank in range(num_replica): | |||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size, | |||||
shuffle = shuffle, drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
sampler.set_distributed(num_replica, rank=rank, pad=pad) | |||||
lengths.append(len(sampler)) | |||||
count = 0 | |||||
for batch in sampler: | |||||
already_seen_indexes.update(batch) | |||||
if count>5: | |||||
break | |||||
count += 1 | |||||
state = sampler.state_dict() | |||||
# switch to a single process
new_batch_size = 6 | |||||
num_batch_per_bucket = 3 | |||||
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, | |||||
shuffle=shuffle, drop_last=drop_last) | |||||
new_sampler.load_state_dict(state) | |||||
repeat_count = 0 | |||||
new_already_seen_indexes = set(list(already_seen_indexes)) | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_seen_indexes)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
for batch in new_sampler: | |||||
for b in batch: | |||||
repeat_count += int(b in new_already_seen_indexes) | |||||
new_already_seen_indexes.update(batch) | |||||
if pad: # a single repeat is allowed
assert repeat_count <= 1 | |||||
else: | |||||
assert repeat_count == 0 | |||||
if drop_last is False: # without dropping, the sets must be equal
assert len(new_already_seen_indexes)==len(dataset) | |||||
# test changing the number of devices.
num_replica = 3 | |||||
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, | |||||
shuffle=shuffle, drop_last=drop_last) | |||||
new_sampler.set_epoch(0) | |||||
new_sampler.load_state_dict(state) | |||||
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad) | |||||
repeat_count = 0 | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_seen_indexes)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
for batch in new_sampler: | |||||
for b in batch: | |||||
repeat_count += int(b in already_seen_indexes) | |||||
if pad: # a single repeat is allowed
assert repeat_count <= 1 | |||||
else: | |||||
assert repeat_count == 0 | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | |||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
batch_size = 6 | |||||
if num_replicas*batch_size > num_samples: | |||||
return | |||||
num_batch_per_bucket = 10 | |||||
samplers = [] | |||||
lengths = [] | |||||
for i in range(num_replicas): | |||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||||
shuffle=shuffle, drop_last=drop_last) | |||||
sampler.set_distributed(num_replicas, rank=i, pad=pad) | |||||
sampler.set_epoch(0) | |||||
samplers.append(sampler) | |||||
lengths.append(len(list(iter(sampler)))) | |||||
assert len(set(lengths))==1 | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
@pytest.mark.parametrize('num_replicas', [1, 2, 3]) | |||||
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): | |||||
""" | |||||
Test that data already consumed (forwarded) can be restored correctly
""" | |||||
batch_size = 6 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
samplers = [] | |||||
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas)) | |||||
for i in range(num_replicas): | |||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||||
shuffle=shuffle, drop_last=drop_last) | |||||
sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) | |||||
samplers.append(sampler) | |||||
count = 0 | |||||
already_seen_sets = [set()] | |||||
already_seen_set = set() | |||||
for batchs in zip(*samplers): | |||||
batch = chain(*batchs) | |||||
already_seen_set.update(batch) | |||||
already_seen_sets.append(deepcopy(already_seen_set)) | |||||
count += 1 | |||||
if count > 3: | |||||
break | |||||
states = samplers[0].state_dict() | |||||
for i in range(len(already_seen_sets)): | |||||
states['num_consumed_samples'] = num_consumed_samples_array[i] | |||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
shuffle=shuffle, drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i]) | |||||
for batch in sampler: | |||||
already_seen_set.update(batch) | |||||
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( | |||||
dataset) | |||||
# test saving again after a save
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last) | |||||
sampler.set_epoch(0) | |||||
states['num_consumed_samples'] = num_consumed_samples_array[2] | |||||
if len(already_seen_sets)<3: | |||||
return | |||||
already_seen_set = already_seen_sets[2] | |||||
count = 0 | |||||
for batch in sampler: | |||||
already_seen_set.update(batch) | |||||
count += 1 | |||||
if count > 6: | |||||
break | |||||
states = sampler.state_dict() | |||||
num_consumed_samples_array = list(range(len(dataset))) | |||||
states['num_consumed_samples'] = num_consumed_samples_array[count] | |||||
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, | |||||
shuffle=shuffle, | |||||
drop_last=drop_last) | |||||
sampler.load_state_dict(states) | |||||
sampler.set_epoch(0) | |||||
for batch in sampler: | |||||
already_seen_set.update(batch) | |||||
assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) |
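A compact restatement of what test_multi_save_load exercises (a sketch; the state layout, num_consumed_samples plus sampler metadata, is assumed from the keys manipulated above):

dataset = DatasetWithVaryLength(num_of_data=100)
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=6,
                             shuffle=True, drop_last=False)
next(iter(sampler))                                 # consume one batch of six samples
state = sampler.state_dict()                        # counts samples, not batches
resumed = RandomBatchSampler(dataset, length=dataset.data, batch_size=3,
                             shuffle=True, drop_last=False)
resumed.set_epoch(0)
resumed.load_state_dict(state)
resumed.set_distributed(num_replicas=3, rank=1, pad=True)  # re-shard freely on resume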
@@ -0,0 +1,141 @@ | |||||
from array import array | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
import pytest | |||||
from fastNLP.core.samplers import ReproduceBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
@pytest.mark.torch | |||||
class TestReproducibleBatchSamplerTorch: | |||||
def test_torch_dataloader_1(self): | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
next(iter_dataloader) | |||||
# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, | |||||
"sampler_type": "ReproduceBatchSampler"} | |||||
# 2. checkpoint-resume: re-create a dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# change batch_size;
after_batch_size = 3 | |||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# check that the second epoch after checkpoint-resume is a complete pass;
# first finish the epoch in which the resume happened;
begin_idx = 27 | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
        # start a new epoch;
begin_idx = 0 | |||||
iter_dataloader = iter(dataloader) | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
def test_torch_dataloader_2(self): | |||||
        # check that the index list of a new epoch is regenerated rather than reused from the previous one;
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
        # enable shuffle to verify that the second epoch after resumption regenerates its index list;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
        # record all data of one epoch to verify that recovery is correct;
all_supposed_data = [] | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
        # 1. save the sampler state
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
        # 2. resume from the checkpoint: build a new dataloader;
        # keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
iter_dataloader = iter(dataloader) | |||||
        # first consume the rest of the current epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
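        # the restored index_list pins the shuffled order, so finishing this epoch
        # must reproduce exactly that order (asserted below)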
while True: | |||||
try: | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
except StopIteration: | |||||
break | |||||
assert all_supposed_data == list(pre_index_list) | |||||
        # start several new epochs; each should be re-shuffled, so its order must differ from the recorded one;
for _ in range(3): | |||||
iter_dataloader = iter(dataloader) | |||||
res = [] | |||||
while True: | |||||
try: | |||||
res.extend(next(iter_dataloader).tolist()) | |||||
except StopIteration: | |||||
break | |||||
assert res != all_supposed_data | |||||
@@ -3,6 +3,7 @@ import pytest | |||||
import subprocess | import subprocess | ||||
from io import StringIO | from io import StringIO | ||||
import sys | import sys | ||||
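# make the repository root importable when this test file is run directly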
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | |||||
from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
@@ -1,4 +1,5 @@ | |||||
import os | import os | ||||
import pytest | |||||
from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
@@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): | |||||
filepath = None | filepath = None | ||||
try: | try: | ||||
with Capturing() as output: | with Capturing() as output: | ||||
dump_fastnlp_backend() | |||||
dump_fastnlp_backend(backend="torch") | |||||
filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') | ||||
assert filepath in output[0] | assert filepath in output[0] | ||||
assert os.path.exists(filepath) | assert os.path.exists(filepath) | ||||
@@ -1,7 +1,9 @@ | |||||
import torch | |||||
from copy import deepcopy | from copy import deepcopy | ||||
from fastNLP.core.callbacks.callback import Callback | from fastNLP.core.callbacks.callback import Callback | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
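# guard the torch import so this helper can be imported even when torch is not installed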
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
class RecordAccumulationStepsCallback_Torch(Callback): | class RecordAccumulationStepsCallback_Torch(Callback): | ||||
@@ -1,13 +1,25 @@ | |||||
import numpy as np | import numpy as np | ||||
import random | |||||
class NormalIterator: | |||||
def __init__(self, num_of_data=1000): | |||||
class NormalSampler: | |||||
def __init__(self, num_of_data=1000, shuffle=False): | |||||
self._num_of_data = num_of_data | self._num_of_data = num_of_data | ||||
self._data = list(range(num_of_data)) | self._data = list(range(num_of_data)) | ||||
if shuffle: | |||||
random.shuffle(self._data) | |||||
self.shuffle = shuffle | |||||
self._index = 0 | self._index = 0 | ||||
self.need_reinitialize = False | |||||
def __iter__(self): | def __iter__(self): | ||||
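        # lazy reset: the first call keeps the current position so a partially
        # consumed epoch can be finished; later calls start a fresh (re-shuffled) epoch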
if self.need_reinitialize: | |||||
self._index = 0 | |||||
if self.shuffle: | |||||
random.shuffle(self._data) | |||||
else: | |||||
self.need_reinitialize = True | |||||
return self | return self | ||||
def __next__(self): | def __next__(self): | ||||
@@ -15,12 +27,45 @@ class NormalIterator: | |||||
raise StopIteration | raise StopIteration | ||||
_data = self._data[self._index] | _data = self._data[self._index] | ||||
self._index += 1 | self._index += 1 | ||||
return self._data | |||||
return _data | |||||
def __len__(self): | def __len__(self): | ||||
return self._num_of_data | return self._num_of_data | ||||
class NormalBatchSampler: | |||||
def __init__(self, sampler, batch_size: int, drop_last: bool) -> None: | |||||
# Since collections.abc.Iterable does not check for `__getitem__`, which | |||||
# is one way for an object to be an iterable, we don't do an `isinstance` | |||||
# check here. | |||||
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \ | |||||
batch_size <= 0: | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got batch_size={}".format(batch_size)) | |||||
if not isinstance(drop_last, bool): | |||||
raise ValueError("drop_last should be a boolean value, but got " | |||||
"drop_last={}".format(drop_last)) | |||||
self.sampler = sampler | |||||
self.batch_size = batch_size | |||||
self.drop_last = drop_last | |||||
def __iter__(self): | |||||
batch = [] | |||||
for idx in self.sampler: | |||||
batch.append(idx) | |||||
if len(batch) == self.batch_size: | |||||
yield batch | |||||
batch = [] | |||||
if len(batch) > 0 and not self.drop_last: | |||||
yield batch | |||||
def __len__(self) -> int: | |||||
if self.drop_last: | |||||
return len(self.sampler) // self.batch_size | |||||
else: | |||||
return (len(self.sampler) + self.batch_size - 1) // self.batch_size | |||||
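
# Minimal usage sketch (illustrative only):
#   sampler = NormalBatchSampler(NormalSampler(num_of_data=10), batch_size=4, drop_last=False)
#   list(sampler)  # -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]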
class RandomDataset: | class RandomDataset: | ||||
def __init__(self, num_data=10): | def __init__(self, num_data=10): | ||||
self.data = np.random.rand(num_data) | self.data = np.random.rand(num_data) | ||||
@@ -29,4 +74,7 @@ class RandomDataset: | |||||
return len(self.data) | return len(self.data) | ||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return self.data[item] | |||||
return self.data[item] | |||||
@@ -1,7 +1,11 @@ | |||||
import torch | import torch | ||||
from functools import reduce | from functools import reduce | ||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import Dataset, DataLoader, DistributedSampler | |||||
from torch.utils.data.sampler import SequentialSampler, BatchSampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
class TorchNormalDataset(Dataset): | class TorchNormalDataset(Dataset): | ||||
@@ -1,9 +1,14 @@ | |||||
import torch | |||||
import torch.nn as nn | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.nn import Module | |||||
import torch.nn as nn | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
# 1. the most basic classification model
class TorchNormalModel_Classification_1(nn.Module): | |||||
class TorchNormalModel_Classification_1(Module): | |||||
""" | """ | ||||
    Implements train_step and evaluate_step separately;
""" | """ | ||||
@@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||||
return {"preds": x, "target": y} | return {"preds": x, "target": y} | ||||
class TorchNormalModel_Classification_2(nn.Module): | |||||
class TorchNormalModel_Classification_2(Module): | |||||
""" | """ | ||||
    Implements only a forward function, to test the scenario where users initialize DDP themselves;
""" | """ | ||||
@@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module): | |||||
return {"loss": loss, "preds": x, "target": y} | return {"loss": loss, "preds": x, "target": y} | ||||
class TorchNormalModel_Classification_3(nn.Module): | |||||
class TorchNormalModel_Classification_3(Module): | |||||
""" | """ | ||||
    Implements only a forward function, to test the scenario where users initialize DDP themselves;
    auto_param_call is disabled, and forward takes only a single batch argument;
@@ -0,0 +1,6 @@ | |||||
[pytest] | |||||
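# backend markers; select a subset with e.g. `pytest -m torch`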
markers =
    torch
    paddle
    jittor
    torchpaddle