
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

Tag: tags/v1.0.0alpha
Author: MorningForest, 3 years ago
Commit efa3d5451b
73 changed files with 2903 additions and 1395 deletions
1. +59 -10 fastNLP/core/__init__.py
2. +2 -3 fastNLP/core/callbacks/__init__.py
3. +27 -29 fastNLP/core/callbacks/callback.py
4. +499 -0 fastNLP/core/callbacks/callback_event.py
5. +0 -206 fastNLP/core/callbacks/callback_events.py
6. +8 -7 fastNLP/core/callbacks/callback_manager.py
7. +0 -1 fastNLP/core/callbacks/has_monitor_callback.py
8. +17 -1 fastNLP/core/collators/__init__.py
9. +6 -2 fastNLP/core/collators/collator.py
10. +30 -0 fastNLP/core/collators/padders/__init__.py
11. +21 -34 fastNLP/core/collators/padders/get_padder.py
12. +8 -1 fastNLP/core/collators/padders/numpy_padder.py
13. +30 -10 fastNLP/core/collators/padders/paddle_padder.py
14. +36 -1 fastNLP/core/collators/padders/raw_padder.py
15. +14 -3 fastNLP/core/collators/padders/torch_padder.py
16. +5 -1 fastNLP/core/collators/padders/utils.py
17. +0 -2 fastNLP/core/controllers/__init__.py
18. +59 -9 fastNLP/core/controllers/trainer.py
19. +26 -27 fastNLP/core/dataloaders/jittor_dataloader/fdl.py
20. +33 -29 fastNLP/core/dataloaders/paddle_dataloader/fdl.py
21. +31 -30 fastNLP/core/dataloaders/torch_dataloader/fdl.py
22. +16 -0 fastNLP/core/dataloaders/utils.py
23. +0 -0 fastNLP/core/dataloaders/utils/__init__.py
24. +1 -10 fastNLP/core/dataset/dataset.py
25. +2 -2 fastNLP/core/drivers/paddle_driver/fleet.py
26. +3 -3 fastNLP/core/drivers/paddle_driver/paddle_driver.py
27. +2 -2 fastNLP/core/drivers/paddle_driver/single_device.py
28. +2 -2 fastNLP/core/drivers/torch_driver/single_device.py
29. +3 -3 fastNLP/core/drivers/torch_driver/torch_driver.py
30. +25 -0 fastNLP/core/log/print.py
31. +3 -2 fastNLP/core/samplers/__init__.py
32. +209 -5 fastNLP/core/samplers/reproducible_batch_sampler.py
33. +3 -2 fastNLP/core/samplers/reproducible_sampler.py
34. +2 -2 fastNLP/core/utils/__init__.py
35. +2 -2 fastNLP/core/utils/dummy_class.py
36. +5 -0 fastNLP/core/utils/paddle_utils.py
37. +1 -20 fastNLP/core/utils/utils.py
38. +35 -4 fastNLP/io/data_bundle.py
39. +208 -0 tests/core/callbacks/test_callback_event.py
40. +0 -157 tests/core/callbacks/test_callback_events.py
41. +13 -9 tests/core/callbacks/test_checkpoint_callback_torch.py
42. +6 -4 tests/core/callbacks/test_more_evaluate_callback.py
43. +24 -1 tests/core/collators/padders/test_get_padder.py
44. +18 -18 tests/core/collators/padders/test_paddle_padder.py
45. +1 -2 tests/core/collators/padders/test_raw_padder.py
46. +287 -75 tests/core/collators/test_collator.py
47. +0 -293 tests/core/collators/test_new_collator.py
48. +219 -10 tests/core/controllers/test_trainer_event_trigger.py
49. +5 -5 tests/core/controllers/test_trainer_other_things.py
50. +6 -4 tests/core/controllers/test_trainer_w_evaluator_torch.py
51. +9 -5 tests/core/controllers/test_trainer_wo_evaluator_torch.py
52. +5 -7 tests/core/controllers/utils/test_utils.py
53. +2 -1 tests/core/drivers/jittor_driver/test_single_device.py
54. +15 -15 tests/core/drivers/paddle_driver/test_single_device.py
55. +3 -3 tests/core/drivers/paddle_driver/test_utils.py
56. +149 -118 tests/core/drivers/torch_driver/test_ddp.py
57. +8 -8 tests/core/drivers/torch_driver/test_initialize_torch_driver.py
58. +12 -12 tests/core/drivers/torch_driver/test_single_device.py
59. +1 -1 tests/core/drivers/torch_driver/test_torch_replace_sampler.py
60. +3 -3 tests/core/drivers/torch_driver/test_utils.py
61. +9 -4 tests/core/metrics/test_accuracy_torch.py
62. +8 -3 tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
63. +9 -4 tests/core/metrics/test_span_f1_rec_acc_torch.py
64. +4 -2 tests/core/metrics/utils.py
65. +433 -153 tests/core/samplers/test_reproducible_batch_sampler.py
66. +141 -0 tests/core/samplers/test_reproducible_batch_sampler_torch.py
67. +1 -0 tests/core/utils/test_cache_results.py
68. +2 -1 tests/envs/test_set_backend.py
69. +3 -1 tests/helpers/callbacks/helper_callbacks_torch.py
70. +52 -4 tests/helpers/datasets/normal_data.py
71. +6 -2 tests/helpers/datasets/torch_data.py
72. +10 -5 tests/helpers/models/torch_model.py
73. +6 -0 tests/pytest.ini

+59 -10 fastNLP/core/__init__.py

@@ -1,4 +1,53 @@
__all__ = [
# callbacks
'Callback',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback",
'MoreEvaluateCallback',
"TorchWarmupCallback",
"TorchGradClipCallback",

# collators
'Collator',
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",

# controllers
'Loop',
'EvaluateBatchLoop',
'TrainBatchLoop',
'Evaluator',
'Trainer',

# dataloaders TODO: mix_dataloader still needs to be sorted out

# dataset
'DataSet',
'FieldArray',
'Instance',
'ApplyResultException',

# drivers
"TorchSingleDriver", "TorchSingleDriver",
"TorchDDPDriver", "TorchDDPDriver",
"PaddleSingleDriver", "PaddleSingleDriver",
@@ -7,16 +56,16 @@ __all__ = [
"JittorMPIDriver", "JittorMPIDriver",
"TorchPaddleDriver", "TorchPaddleDriver",


"paddle_to",
"get_paddle_gpu_str",
"get_paddle_device_id",
"paddle_move_data_to_device",
"torch_paddle_move_data_to_device",
]
# TODO: optimize the imports here later; each sub module should first import its own classes and functions, and the outer module should then import directly from the submodules;
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.evaluator import Evaluator
from fastNLP.core.dataloaders.torch_dataloader import *
# log
"logger"


#
]
from .callbacks import *
from .collators import *
from .controllers import *
from .dataloaders import *
from .dataset import *
from .drivers import *
from .log import *
from .utils import *

+2 -3 fastNLP/core/callbacks/__init__.py

@@ -1,7 +1,6 @@
__all__ = [
'Callback',
'Events',
'EventsList',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
@@ -20,7 +19,7 @@ __all__ = [




from .callback import Callback
from .callback_events import EventsList, Events, Filter
from .callback_event import Event, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback


+27 -29 fastNLP/core/callbacks/callback.py

@@ -3,10 +3,9 @@ __all__ = [
'Callback',
]


from typing import Union, Callable, Dict, Optional, Any
from typing import Callable, Dict, Optional


from .callback_events import Events, EventsList, Filter
from fastNLP.core.callbacks.callback_events import _SingleEventState
from .callback_event import Event, Filter




class Callback:
@@ -14,32 +13,35 @@ class Callback:
Any callback class actually used, whether one of the callback classes fastNLP provides by default or one the user customizes, should inherit from this base class;
the callback hooks fire roughly in the following order:
Trainer.__init__():
on_after_trainer_initialized()
on_after_trainer_initialized(trainer, driver)
Trainer.run():
if num_eval_sanity_batch>0:
on_sanity_check_begin() # if num_eval_sanity_batch is set
on_sanity_check_end()
on_sanity_check_begin(trainer) # if num_eval_sanity_batch is set
on_sanity_check_end(trainer, sanity_check_res)
try:
on_train_begin()
on_train_begin(trainer)
while cur_epoch_idx < n_epochs:
on_train_epoch_begin()
on_train_epoch_begin(trainer)
while batch_idx_in_epoch<=num_batches_per_epoch:
on_fetch_data_begin()
on_fetch_data_end()
on_train_batch_begin()
on_before_backward()
on_after_backward()
on_before_zero_grad() # actual invocation is affected by accumulation_steps
on_after_zero_grad() # actual invocation is affected by accumulation_steps
on_before_optimizers_step() # actual invocation is affected by accumulation_steps
on_after_optimizers_step() # actual invocation is affected by accumulation_steps
on_train_batch_end()
on_train_epoch_end()
on_fetch_data_begin(trainer)
batch = next(dataloader)
on_fetch_data_end(trainer)
on_train_batch_begin(trainer, batch, indices)
on_before_backward(trainer, outputs) # outputs has passed through output_mapping (if set), otherwise it is the model output.
on_after_backward(trainer)
on_before_zero_grad(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_after_zero_grad(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_before_optimizers_step(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_after_optimizers_step(trainer, optimizers) # actual invocation is affected by accumulation_steps
on_train_batch_end(trainer)
on_train_epoch_end(trainer)
except BaseException:
self.on_exception()
self.on_exception(trainer, exception)
finally:
on_train_end()
other callbacks such as on_evaluate_begin()/on_evaluate_end() will
on_train_end(trainer)
other callbacks such as on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer) are invoked as needed at specific
points inside Trainer.run().
"""


def on_after_trainer_initialized(self, trainer, driver):
@@ -294,18 +296,14 @@ class _CallbackWrapper(Callback):
For callback functions the user adds via the function decorator, this _CallbackWrapper class customizes them; it keeps only the user's
single callback function;
"""
def __init__(self, event: Union[Events, EventsList], fn: Callable):
def __init__(self, event: Event, fn: Callable):
r""" r"""
:param event: 具体的 callback 时机,例如 'on_train_begin' 等;可以多个时机,此时 `event` 的 type 应当为 'EventsList';
:param event: 具体的 callback 时机,例如 'on_train_begin' 等;
:param fn: 用户定制的 callback 函数; :param fn: 用户定制的 callback 函数;
""" """


self.fn = fn
if isinstance(event, EventsList):
for each_event in event:
_filter = Filter(each_event.every, each_event.once, each_event.filter_fn)
setattr(self, each_event.value, _filter(fn))
elif isinstance(event, _SingleEventState):
if isinstance(event, Event):
_filter = Filter(event.every, event.once, event.filter_fn)
setattr(self, event.value, _filter(fn))
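
To make the hook order above concrete, here is a minimal usage sketch (not part of this commit) of a user-defined callback; the hook signatures and the trainer attributes cur_epoch_idx and batch_idx_in_epoch are taken from the lifecycle listing in the docstring above.

from fastNLP import Callback  # import path assumed per this commit's fastNLP/core/__init__.py

class ProgressPrinterCallback(Callback):
    def on_train_batch_begin(self, trainer, batch, indices):
        # attribute names follow the loop sketch in the Callback docstring
        print(f"epoch {trainer.cur_epoch_idx}, batch {trainer.batch_idx_in_epoch}")

    def on_train_epoch_end(self, trainer):
        print(f"epoch {trainer.cur_epoch_idx} finished")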




+499 -0 fastNLP/core/callbacks/callback_event.py

@@ -0,0 +1,499 @@
from typing import Optional, Callable, Dict
from functools import wraps


__all__ = [
'Event',
'Filter'
]


def check_legality(fn):
@wraps(fn)
def wrap(every=None, once=None, filter_fn=None):
if (every is None) and (once is None) and (filter_fn is None):
every = 1

if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)):
raise ValueError("These three values should be only set one.")

if (filter_fn is not None) and not callable(filter_fn):
raise TypeError("Argument filter_fn should be a callable")

if (every is not None) and not (isinstance(every, int) and every > 0):
raise ValueError("Argument every should be integer and greater than zero")

if (once is not None) and not (isinstance(once, int) and once > 0):
raise ValueError("Argument once should be integer and positive")
return fn(every=every, once=once, filter_fn=filter_fn)
return wrap


class Event:
every: Optional[int]
once: Optional[int]

def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None,
filter_fn: Optional[Callable] = None):
"""
Do not construct this object directly; obtain it through calls such as Event.on_after_trainer_initialized().

:param value: the callback moment in the Trainer.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
"""
self.every = every
self.once = once
self.filter_fn = filter_fn
self.value = value

def __str__(self):
return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once,
self.filter_fn)
@staticmethod
@check_legality
def on_after_trainer_initialized(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_after_trainer_initialized.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_sanity_check_begin(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_sanity_check_begin.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_sanity_check_end(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_sanity_check_end.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_begin(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_train_begin.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_end(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_train_end.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_epoch_begin(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_train_epoch_begin.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_epoch_end(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_train_epoch_end.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_fetch_data_begin(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_fetch_data_begin.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_fetch_data_end(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_fetch_data_end.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_batch_begin(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_train_batch_begin.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_batch_end(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_train_batch_end.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_exception(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_exception.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_save_model(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_save_model.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_load_model(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_load_model.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_save_checkpoint(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_save_checkpoint.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_load_checkpoint(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_load_checkpoint.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn)


@staticmethod
@check_legality
def on_before_backward(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_before_backward.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_after_backward(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_after_backward.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_before_optimizers_step(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_before_optimizers_step.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_after_optimizers_step(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_after_optimizers_step.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_before_zero_grad(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_before_zero_grad.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_after_zero_grad(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_after_zero_grad.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_evaluate_begin(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_evaluate_begin.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_evaluate_end(every=None, once=None, filter_fn=None):
"""
When the Trainer reaches on_evaluate_end.

The following three parameters are mutually exclusive; only one of them may be set. The default behavior is equivalent to every=1.
:param int every: how many times the event must fire before the function actually runs once.
:param bool once: whether to execute only once and never again.
:param Callable filter_fn: its input parameters should be (filter, trainer), where the filter object carries the two variables filter.num_called and
filter.num_executed, which give how many times the callback has been triggered and how many times it has actually executed. The trainer object is the Trainer currently running.
:return:
"""
return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn)


class Filter:
def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None):
r"""
Use this `Filter` as a function decorator to control how often the decorated function actually runs;

:param every: the function runs once every `every` calls;
:param once: the function runs only once;
:param filter_fn: a user-defined frequency-control function; note that the frequency logic inside it should be stateless
apart from the attributes `self.num_called` and `self.num_executed`, because we reset these two counters after the pre-run;
"""
# check legality
check_legality(lambda *args, **kwargs: ...)(every, once, filter_fn)
if (every is None) and (once is None) and (filter_fn is None):
every = 1
# set up the variables, including the global ones;
self.num_called = 0
self.num_executed = 0

if every is not None:
self._every = every
self._filter = self.every_filter
elif once is not None:
self._once = once
self._filter = self.once_filter
else:
self._filter = filter_fn

def __call__(self, fn: Callable):

@wraps(fn)
def wrapper(*args, **kwargs) -> Callable:
self.num_called += 1

# because the inputs of our callback functions are fixed, and we can guarantee that the first argument is always the trainer;
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)

wrapper.__fastNLP_filter__ = self
return wrapper

def every_filter(self, *args):
return self.num_called % self._every == 0

def once_filter(self, *args):
return self.num_called == self._once

def state_dict(self) -> Dict:
r"""
Save the state of this `Filter` via this function;
"""
return {"num_called": self.num_called, "num_executed": self.num_executed}

def load_state_dict(self, state: Dict):
r"""
Load the state of this `Filter` via this function;

:param state: the state saved by the `Filter.state_dict` function;
"""
self.num_called = state["num_called"]
self.num_executed = state["num_executed"]
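
A minimal sketch (not part of the diff) of how the two classes above interact: an Event staticmethod validates the every/once/filter_fn combination, while Filter, used directly as a decorator, does the actual call counting. The DummyTrainer stand-in is hypothetical; in real use the first positional argument is the running Trainer.

ev = Event.on_train_batch_begin(every=2)
print(ev)  # <event=on_train_batch_begin, every=2, once=None, filter fn is:None>

f = Filter(every=2)

@f
def hook(trainer):
    print(f"executed on call #{f.num_called}")

class DummyTrainer:  # stand-in; Filter only forwards args[0] to the filter function
    pass

for _ in range(4):
    hook(DummyTrainer())   # executes on calls 2 and 4 only
print(f.state_dict())      # {'num_called': 4, 'num_executed': 2}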








+0 -206 fastNLP/core/callbacks/callback_events.py

@@ -1,206 +0,0 @@
from enum import Enum, unique
from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict
from types import DynamicClassAttribute
from functools import wraps


__all__ = [
'Events',
'EventsList',
'Filter'
]


class _SingleEventState:
every: Optional[int]
once: Optional[int]

def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None,
filter_fn: Optional[Callable] = None, name: Optional[str] = None):

# the actual argument-validation logic lives in the concrete Filter;
if every is None and once is None and filter_fn is None:
self.every = 1
self.once = None
self.filter_fn = None
else:
self.every = every
self.once = once
self.filter_fn = filter_fn

if not hasattr(self, "_value_"):
self._value_ = value

if not hasattr(self, "_name_") and name is not None:
self._name_ = name

# copied to be compatible to enum
@DynamicClassAttribute
def name(self) -> str:
"""The name of the Enum member."""
return self._name_

@DynamicClassAttribute
def value(self) -> str:
"""The value of the Enum member."""
return self._value_

def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
return _SingleEventState(self.value, every, once, filter_fn, self.name)

def __str__(self):
return "<event={0}, every={1}, once={2}, filter fn is None:{3}>".format(self.name, self.every, self.once,
self.filter_fn)

def __eq__(self, other) -> bool:
if isinstance(other, _SingleEventState):
return self.name == other.name
elif isinstance(other, str):
return self.name == other
else:
raise NotImplemented

def __hash__(self):
return hash(self._name_)

def __or__(self, other) -> "EventsList":
return EventsList() | self | other


class EventEnum(_SingleEventState, Enum):
pass

@unique
class Events(EventEnum):
on_after_trainer_initialized = "on_after_trainer_initialized"
on_sanity_check_begin = "on_sanity_check_begin"
on_sanity_check_end = "on_sanity_check_end"
on_train_begin = "on_train_begin"
on_train_end = "on_train_end"
on_train_epoch_begin = "on_train_epoch_begin"
on_train_epoch_end = "on_train_epoch_end"
on_fetch_data_begin = "on_fetch_data_begin"
on_fetch_data_end = "on_fetch_data_end"
on_train_batch_begin = "on_train_batch_begin"
on_train_batch_end = "on_train_batch_end"
on_exception = "on_exception"
on_save_model = "on_save_model"
on_load_model = "on_load_model"
on_save_checkpoint = "on_save_checkpoint"
on_load_checkpoint = "on_load_checkpoint"
on_before_backward = "on_before_backward"
on_after_backward = "on_after_backward"
on_before_optimizers_step = "on_before_optimizers_step"
on_after_optimizers_step = "on_after_optimizers_step"
on_before_zero_grad = "on_before_zero_grad"
on_after_zero_grad = "on_after_zero_grad"
on_evaluate_begin = "on_evaluate_begin"
on_evaluate_end = "on_evaluate_end"


class EventsList:
"""Collection of events stacked by operator `__or__`.
"""

def __init__(self) -> None:
self._events = [] # type: List[Union[Events, _SingleEventState]]

def _append(self, event: Union[Events, _SingleEventState]) -> None:
if not isinstance(event, (Events, _SingleEventState)):
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}")
self._events.append(event)

def __getitem__(self, item: int) -> Union[Events, _SingleEventState]:
return self._events[item]

def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]:
return iter(self._events)

def __len__(self) -> int:
return len(self._events)

def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList":
self._append(event=other)
return self


class Filter:
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
r"""
Use this `Filter` as a function decorator to control how often the decorated function actually runs;

:param every: the function runs once every `every` calls;
:param once: the function runs only once, at the `once`-th call;
:param filter_fn: a user-defined frequency-control function; note that the frequency logic inside it should be stateless
apart from the attributes `self.num_called` and `self.num_executed`, because we reset these two counters after the pre-run;
"""
if (every is None) and (once is None) and (filter_fn is None):
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.")

if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)):
raise ValueError("These three values should be only set one.")

if (filter_fn is not None) and not callable(filter_fn):
raise TypeError("Argument event_filter should be a callable")

if (every is not None) and not (isinstance(every, int) and every > 0):
raise ValueError("Argument every should be integer and greater than zero")

if (once is not None) and not (isinstance(once, int) and once > 0):
raise ValueError("Argument once should be integer and positive")

# set up the variables, including the global ones;
self.num_called = 0
self.num_executed = 0

if every is not None:
self._every = every
self._filter = self.every_filter
elif once is not None:
self._once = once
self._filter = self.once_filter
else:
self._filter = filter_fn

def __call__(self, fn: Callable):

@wraps(fn)
def wrapper(*args, **kwargs) -> Callable:
self.num_called += 1

# because the inputs of our callback functions are fixed, and we can guarantee that the first argument is always the trainer;
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)

wrapper.__fastNLP_filter__ = self
return wrapper

def every_filter(self, *args):
return self.num_called % self._every == 0

def once_filter(self, *args):
return self.num_called == self._once

def state_dict(self) -> Dict:
r"""
Save the state of this `Filter` via this function;
"""
return {"num_called": self.num_called, "num_executed": self.num_executed}

def load_state_dict(self, state: Dict):
r"""
Load the state of this `Filter` via this function;

:param state: the state saved by the `Filter.state_dict` function;
"""
self.num_called = state["num_called"]
self.num_executed = state["num_executed"]








+8 -7 fastNLP/core/callbacks/callback_manager.py

@@ -6,7 +6,7 @@ __all__ = [
'CallbackManager'
]


from .callback_events import Events
from .callback_event import Event
from .callback import Callback
from fastNLP.core.log import logger
from .progress_callback import ProgressCallback, choose_progress_callback
@@ -110,7 +110,7 @@ class CallbackManager:
def initialize_class_callbacks(self):
r"""
During the actual run, we split a concrete callback instance into its individual callback functions and add them to a dictionary whose keys are
the individual callback moments, i.e. the categories of `Events`;
the individual callback moments, i.e. the categories of `Event`;
if a callback function of a callback class has no effect at all, we do not actually add it to the dictionary;


:param callbacks:
@@ -127,11 +127,12 @@ class CallbackManager:
:param callback: a concrete callback instance;
"""
self.all_callbacks.append(callback)
for name, member in Events.__members__.items():
_fn = getattr(callback, member.value)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)):
self.callback_fns[member.value].append(_fn)
self.extract_callback_filter_state(callback.callback_name, _fn)
for name, member in Event.__dict__.items():
if isinstance(member, staticmethod):
_fn = getattr(callback, name)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)):
self.callback_fns[name].append(_fn)
self.extract_callback_filter_state(callback.callback_name, _fn)


def extract_callback_filter_state(self, callback_name, callback_fn):
r"""


+0 -1 fastNLP/core/callbacks/has_monitor_callback.py

@@ -161,7 +161,6 @@ class MonitorUtility:
return monitor_name





class HasMonitorCallback(MonitorUtility, Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""


+17 -1 fastNLP/core/collators/__init__.py

@@ -1,4 +1,20 @@
__all__ = [
'Collator'
'Collator',

'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",
]
from .collator import Collator
from .padders import *

+6 -2 fastNLP/core/collators/collator.py

@@ -65,12 +65,16 @@ def _get_backend() -> str:
return catch_backend[0]


# method (2)
for backend in CHECK_BACKEND:
if backend in sys.modules:
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
return backend
for key, module in sys.modules.items():
catch_backend = _check_module(module)
if catch_backend:
break
if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
return catch_backend[0]


return 'numpy'
@@ -227,7 +231,7 @@ class Collator:
Set what type of tensor the fields that can be padded are padded to by default


:param backend: which kind of tensor to use for fields that can be padded; supports ['torch','jittor','paddle','numpy','raw', 'auto', None],
if 'auto', the backend is determined from the calling environment when padding.
if 'auto', the backend is automatically determined from the calling environment when padding.
:return:
"""
assert backend in SUPPORTED_BACKENDS
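
Method (2) above is what lets backend='auto' pick up a framework the user has already imported. A rough sketch of the fallback chain (the CHECK_BACKEND list is an assumption about what collator.py defines elsewhere):

import sys

CHECK_BACKEND = ['torch', 'paddle', 'jittor']  # assumed order
backend = next((b for b in CHECK_BACKEND if b in sys.modules), 'numpy')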


+30 -0 fastNLP/core/collators/padders/__init__.py

@@ -0,0 +1,30 @@

__all__ = [
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",

"Padder",
"NullPadder",

"RawNumberPadder",
"RawSequencePadder",

'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',

"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",

"get_padded_numpy_array",
]


from .numpy_padder import *
from .padder import Padder, NullPadder
from .raw_padder import *
from .torch_padder import *
from .paddle_padder import *
from .utils import get_padded_numpy_array

+21 -34 fastNLP/core/collators/padders/get_padder.py

@@ -1,8 +1,3 @@

from typing import Dict



from typing import Sequence, Any, Union, Dict
from abc import ABC


@@ -12,7 +7,7 @@ from fastNLP.core.log import logger
from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from .exceptions import *


@@ -28,7 +23,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
:param field_name: used to produce clearer error messages.
:return:
"""
assert len(batch_field)!=0, "Empty batch encountered."
logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
if pad_val is None:
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
@@ -68,7 +63,10 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return NullPadder()


# then check whether all elements share the same type
ele_dtypes = set([v[1] for v in catalog.values()])
try:
ele_dtypes = set([v[1] for v in catalog.values()])
except TypeError:
ele_dtypes = set([str(v[1]) for v in catalog.values()])
num_eletypes = len(ele_dtypes)
if num_eletypes != 1:
msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
@@ -80,7 +78,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->


depth = depths.pop()
shape_len = shape_lens.pop()
ele_dtype = ele_dtypes.pop()
ele_dtype = list(catalog.values())[0][1]  # handled this way because of the except branch above


# the padder itself has to decide whether it can pad.
try:
@@ -93,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")


if depth > 1 and shape_len == 0:  # e.g. [[0, 1], [2]]
if backend == 'raw':
@@ -103,14 +103,21 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")


if depth == 1 and shape_len != 0:
if backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
# if the elements have a shape, this only works when the object has a tolist() method
if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)):
if backend == 'raw':
return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'torch':
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")


if shape_len != 0 and depth>1:
msg = "Does not support pad tensor under nested list. If you need this, please report."
@@ -179,23 +186,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
else:  # includes int/float/bool/dict and other things that cannot be padded
catalog[parent] = ((), type(content))  # () means the shape has length 0; the second element is its type
return catalog




"""
from numbers import Number

issubclass(type(3), Number) # True
issubclass(type(3.1), Number) # True
issubclass(type('3'), Number) # False
issubclass(type(True), Number) # True
issubclass(type(np.zeros(3)[0]), Number) # True
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True; distinguishing it requires a further check
is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype)
"""




+8 -1 fastNLP/core/collators/padders/numpy_padder.py

@@ -66,7 +66,7 @@ class NumpySequencePadder(Padder):
class NumpyTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
pad a field like [np.array([3, 4]), np.array([1])]
pad a field like [np.array([3, 4]), np.array([1])]. If the elements are not np.ndarray, they must have a tolist() method.

:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of the current field can be converted to np.array.
@@ -77,6 +77,13 @@ class NumpyTensorPadder(Padder):


@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], np.ndarray):
batch_field = [np.array(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
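
A quick sketch of the padding behavior implemented above (not part of the diff): ragged arrays are padded on the right with pad_val up to the per-dimension maximum (the fill step is truncated from this hunk, so the output shown is the expected result).

import numpy as np

batch = [np.array([3, 4]), np.array([1])]
NumpyTensorPadder.pad(batch, pad_val=0, dtype=int)
# expected: array([[3, 4],
#                  [1, 0]])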


+30 -10 fastNLP/core/collators/padders/paddle_padder.py

@@ -56,7 +56,7 @@ def is_paddle_dtype_str(dtype):




def _get_dtype(ele_dtype, dtype, class_name):
if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")


@@ -74,13 +74,20 @@ def _get_dtype(ele_dtype, dtype, class_name):
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
else:
dtype == ele_dtype
dtype = ele_dtype


return dtype




class PaddleNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3])

:param pad_val: 该值无意义
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等
"""
# only when ele_dtype is a python number / numpy number or tensor
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@@ -91,7 +98,14 @@ class PaddleNumberPadder(Padder):




class PaddleSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, ele_dtype=None, pad_val=0, dtype=None):
"""
将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。

:param pad_val: pad 的值。
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)


@@ -102,19 +116,26 @@ class PaddleSequencePadder(Padder):




class PaddleTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
""" """
目前支持 [paddle.tensor([3, 2], paddle.tensor([1])] 类似的
目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。


:param ele_dtype:
:param pad_val:
:param dtype:
:param pad_val: pad 的值。
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等
""" """
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype) super().__init__(pad_val=pad_val, dtype=dtype)


@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], paddle.Tensor):
batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
@@ -174,6 +195,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
""" """
shapes = get_shape(batch_field) shapes = get_shape(batch_field)
tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
# tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor

+36 -1 fastNLP/core/collators/padders/raw_padder.py

@@ -1,4 +1,8 @@

__all__ = [
"RawNumberPadder",
"RawSequencePadder",
"RawTensorPadder"
]


from .padder import Padder
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number
@@ -63,3 +67,34 @@ class RawSequencePadder(Padder):
:return:
"""
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()


class RawTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。

:param pad_val: pad 的值
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。
:param dtype: 输出的数据的 dtype 是什么
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
"""

:param batch_field:
:param pad_val:
:param dtype: this parameter has no effect.
:return:
"""
try:
if not isinstance(batch_field[0], (list, tuple)):
batch_field = [field.tolist() for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a list or tuple(it is {type(batch_field[0])}), "
f"it must have tolist() method.")

return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()
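
Sketch of the new RawTensorPadder (not part of the diff): array-like elements are converted via tolist() and padded into plain Python lists, assuming get_padded_numpy_array accepts dtype=None.

import numpy as np

batch = [np.array([1, 2]), np.array([3])]
RawTensorPadder.pad(batch, pad_val=0, dtype=None)
# expected: [[1, 2], [3, 0]]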

+14 -3 fastNLP/core/collators/padders/torch_padder.py

@@ -1,4 +1,8 @@

__all__ = [
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder'
]
from inspect import isclass
import numpy as np


@@ -37,7 +41,7 @@ def is_torch_tensor(dtype):




def _get_dtype(ele_dtype, dtype, class_name):
if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")


@@ -97,7 +101,7 @@ class TorchSequencePadder(Padder):
class TorchTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
currently supports things like [torch.tensor([3, 2]), torch.tensor([1])]
currently supports things like [torch.tensor([3, 2]), torch.tensor([1])]. If the elements are not torch.tensor, they must have a tolist() method.


:param pad_val: the value to pad with.
:param ele_dtype: used to check whether the element type of the current field can be converted to torch.tensor.
@@ -108,6 +112,13 @@ class TorchTensorPadder(Padder):


@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], torch.Tensor):
batch_field = [torch.tensor(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)


+5 -1 fastNLP/core/collators/padders/utils.py

@@ -1,6 +1,10 @@


__all__ = [
'get_padded_numpy_array'
]


from typing import Sequence, List from typing import Sequence, List
from numbers import Number
import re
from inspect import isclass




+0 -2 fastNLP/core/controllers/__init__.py

@@ -2,8 +2,6 @@ __all__ = [
'Loop',
'EvaluateBatchLoop',
'TrainBatchLoop',
'State',
'TrainerState',
'Evaluator',
'Trainer',
]


+59 -9 fastNLP/core/controllers/trainer.py

@@ -17,10 +17,10 @@ from .utils import State, TrainerState
from .utils.utils import check_evaluate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks import Callback, CallbackManager
from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_manager import prepare_callbacks from fastNLP.core.callbacks.callback_manager import prepare_callbacks
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.callback_event import Event
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger):
raise e
finally:
self.on_train_end()
self.driver.barrier()


def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
@@ -399,7 +398,7 @@ class Trainer(TrainerEventTrigger):
if self.cur_epoch_idx % evaluate_every == 0:
self.run_evaluate()


def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
def add_callback_fn(self, event: Event, fn: Callable):
r""" r"""
在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; 在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数;
这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; 这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数;
@@ -407,19 +406,69 @@ class Trainer(TrainerEventTrigger):
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机;
:param fn: 具体的 callback 函数; :param fn: 具体的 callback 函数;
""" """
if not isinstance(event, (_SingleEventState, EventsList)):
raise ValueError("parameter event should only be `Events` or `EventsList` type.")
if not isinstance(event, Event):
raise ValueError("parameter event should only be `Event` type.")


_custom_callback = _CallbackWrapper(event, fn)
self.callback_manager.dissect_one_callback(_custom_callback)


@classmethod
def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None):
def on(cls, event: Event, marker: Optional[str] = None):
r"""
A function decorator. Users can use it to conveniently turn a function into a callback function, thereby controlling the training flow;
The supported event moments are listed below, in the order in which they fire. The parameter list each decorated function should accept is also shown below, for example
Trainer.__init__():
    on_after_trainer_initialized(trainer, driver)
Trainer.run():
    if num_eval_sanity_batch>0:
        on_sanity_check_begin(trainer)  # only if num_eval_sanity_batch is set
        on_sanity_check_end(trainer, sanity_check_res)
    try:
        on_train_begin(trainer)
        while cur_epoch_idx < n_epochs:
            on_train_epoch_begin(trainer)
            while batch_idx_in_epoch<=num_batches_per_epoch:
                on_fetch_data_begin(trainer)
                batch = next(dataloader)
                on_fetch_data_end(trainer)
                on_train_batch_begin(trainer, batch, indices)
                on_before_backward(trainer, outputs)  # outputs have passed through output_mapping (if set); otherwise they are the raw model outputs.
                on_after_backward(trainer)
                on_before_zero_grad(trainer, optimizers)  # actual invocation is affected by accumulation_steps
                on_after_zero_grad(trainer, optimizers)  # actual invocation is affected by accumulation_steps
                on_before_optimizers_step(trainer, optimizers)  # actual invocation is affected by accumulation_steps
                on_after_optimizers_step(trainer, optimizers)  # actual invocation is affected by accumulation_steps
                on_train_batch_end(trainer)
            on_train_epoch_end(trainer)
    except BaseException:
        self.on_exception(trainer, exception)
    finally:
        on_train_end(trainer)
Other callbacks, e.g. on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer), are invoked at the appropriate
points inside Trainer.run() as needed.

Example::
    from fastNLP import Event

    @Trainer.on(Event.on_save_model())
    def do_something_1(trainer):
        # do something
    # The function above runs whenever the Trainer saves the model.

    @Trainer.on(Event.on_save_model(once=True))
    def do_something_2(trainer):
        # do something
    # The function above runs when the Trainer saves the model, but only once.

    @Trainer.on(Event.on_train_batch_begin(every=2))
    def do_something_3(trainer, batch, indices):
        # do something
    # The function above runs at the start of every new batch, but only once every two batches.

Note that if you use this decorator to add callbacks to your training, make sure the code adding the callback function runs before the `Trainer` is instantiated;


:param event: the specific callback moment; the user must specify which callback moment this callback function belongs to;
:param event: the specific callback moment; the user must specify which callback moment this callback function belongs to. The function
run at each moment must accept specific parameters, which can be looked up in the description above.
:param marker: used to mark which specific trainer instances this callback function belongs to. Two special cases: 1. when `marker` is None (the default),
the callback function belongs only to the nearest trainer instance below it in the code; 2. when `marker` is 'all', the callback function is used by every
trainer instance;
@@ -427,9 +476,9 @@ class Trainer(TrainerEventTrigger):
"""


def wrapper(fn: Callable) -> Callable:
cls._custom_callbacks[marker].append((event, fn))
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
_check_valid_parameters_number(fn, callback_fn_args)
cls._custom_callbacks[marker].append((event, fn))
return fn


return wrapper
@@ -441,6 +490,7 @@ class Trainer(TrainerEventTrigger):
"""
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
_own_callbacks.extend(self._custom_callbacks[None])
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
self._custom_callbacks[None] = []
if self.marker is not None:
if len(self._custom_callbacks[self.marker]) == 0:
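A minimal, hedged sketch of both registration paths above, borrowing the helper model/dataset used by this commit's tests (the top-level `Trainer`/`Event` exports follow the docstring; none of this is re-verified here)::

    from torch.optim import SGD
    from torch.utils.data import DataLoader

    from fastNLP import Trainer, Event
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
    from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

    # Decorator path: must run before the Trainer is instantiated.
    @Trainer.on(Event.on_train_batch_begin(every=2))
    def log_batch(trainer, batch, indices):
        print("batch indices:", indices)

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
    dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
    trainer = Trainer(model=model, driver="torch", device="cpu",
                      train_dataloader=DataLoader(dataset, batch_size=4),
                      optimizers=SGD(model.parameters(), lr=0.0001))

    # Instance path: add_callback_fn() now takes a single Event.
    trainer.add_callback_fn(Event.on_train_epoch_end(), lambda trainer: print("epoch done"))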


+ 26
- 27
fastNLP/core/dataloaders/jittor_dataloader/fdl.py View File

@@ -14,7 +14,7 @@ else:
from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet




@@ -107,33 +107,33 @@ class JittorDataLoader:
return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1


def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "JittorDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
If you need to make special adjustments to the content of some field, use this function.

:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to None explicitly.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray,
torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this may only be None or numpy.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: the Collator itself
If you need to make special adjustments to the content of some field, use this function.

:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to
None explicitly. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this value is meaningless.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def set_ignore(self, *field_names) -> "JittorDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content should not be output, it can be set here; the fields so marked will be ignored in the batch output.
Ex::
@@ -146,18 +146,17 @@ class JittorDataLoader:
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]:
"""
Get the idx of the current data
Get the idx of the current batch


:return:
"""
return self.cur_batch_indices



def prepare_jittor_dataloader():
...
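The same set_pad()/set_ignore() surface is mirrored on the paddle and torch dataloaders below; a hedged usage sketch against the torch variant (the import path follows this diff's file layout, and the toy dataset is illustrative)::

    from fastNLP import DataSet
    from fastNLP.core.dataloaders.torch_dataloader.fdl import TorchDataLoader

    ds = DataSet({"words": [[1, 2], [3, 4, 5]], "raw": ["a b", "c d e"]})
    dl = TorchDataLoader(ds, batch_size=2)

    # Both setters now hand back the Collator instead of the dataloader,
    # so per-field tweaks no longer chain on the dataloader itself.
    dl.set_pad("words", pad_val=-1)  # pad 'words' with -1
    dl.set_ignore("raw")             # drop 'raw' from every batch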

+ 33
- 29
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -15,8 +15,9 @@ else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader


from fastNLP.core.collators.collator import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler




class _PaddleDataset(Dataset):
@@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader):
if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset)


if batch_sampler is None:
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last)

super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
@@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader):
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle")
# if collate_fn is not None:
# self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = Collator(backend="paddle")


@@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader):
self.cur_batch_indices = indices
yield data


def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "PaddleDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
"""
If you need to make special adjustments to the content of some field, use this function.

:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to None explicitly.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray,
torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this may only be None or numpy.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: the Collator itself
If you need to make special adjustments to the content of some field, use this function.

:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to
None explicitly. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this value is meaningless.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def set_ignore(self, *field_names) -> "PaddleDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content should not be output, it can be set here; the fields so marked will be ignored in the batch output.
Ex::
@@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader):
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]:
"""
Get the idx of the current data
Get the idx of the current batch


:return:
"""
@@ -147,7 +150,8 @@ class PaddleDataLoader(DataLoader):




def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
return_list: bool = True,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
num_workers: int = 0, use_buffer_reader: bool = True,


+ 31
- 30
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -3,14 +3,14 @@ __all__ = [
'prepare_torch_dataloader'
]


from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List


from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler


if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
@@ -76,6 +76,10 @@ class TorchDataLoader(DataLoader):
if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset)


if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -87,9 +91,6 @@ class TorchDataLoader(DataLoader):
if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is used
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
# if collate_fn is not None and collate_fn is not default_collate:
# # prevent torch dataloader's default collate from being added back when ddp re-initializes
# self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = Collator(backend="torch")
else:
@@ -112,31 +113,32 @@ class TorchDataLoader(DataLoader):
yield data


def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> "TorchDataLoader":
pad_fn:Callable=None) -> Collator:
"""
If you need to make special adjustments to the content of some field, use this function.
If you need to make special adjustments to the content of some field, use this function.


:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to None explicitly.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'], producing output of type list, numpy.ndarray,
torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this may only be None or numpy.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: the Collator itself
:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to
None explicitly. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this value is meaningless.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def set_ignore(self, *field_names) -> "TorchDataLoader":
def set_ignore(self, *field_names) -> Collator:
"""
If some content should not be output, it can be set here; the fields so marked will be ignored in the batch output.
Ex::
@@ -149,24 +151,23 @@ class TorchDataLoader(DataLoader):
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]:
"""
Get the idx of the current data
Get the idx of the current batch


:return:
"""
return self.cur_batch_indices





def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
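With the widened annotations above, fastNLP's reproducible samplers can be passed straight through; a hedged sketch (the toy dataset is illustrative)::

    from fastNLP import DataSet
    from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader
    from fastNLP.core.samplers import RandomSampler

    ds = DataSet({"x": [[1, 2], [3, 4, 5]], "y": [0, 1]})
    # A ReproducibleSampler makes the resulting dataloader resumable after a checkpoint.
    dl = prepare_torch_dataloader(ds, batch_size=2, sampler=RandomSampler(ds, shuffle=True))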


+ 16
- 0
fastNLP/core/dataloaders/utils.py View File

@@ -0,0 +1,16 @@
def indice_collate_wrapper(func):
"""
Wraps the collate_fn one level deeper: it separates the tuple data fetched from the dataset and packs the idx values into indices.

:param func: the function to wrap
:return:
"""

def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)

return wrapper
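What the relocated wrapper does, as a self-contained check (the tuple layout mimics what an index-aware dataset yields)::

    from fastNLP.core.dataloaders.utils import indice_collate_wrapper

    @indice_collate_wrapper
    def collate(ins_list):
        return list(ins_list)

    indices, batch = collate([(0, "a"), (3, "b"), (7, "c")])
    assert indices == [0, 3, 7] and batch == ["a", "b", "c"]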

+ 0
- 0
fastNLP/core/dataloaders/utils/__init__.py View File


+ 1
- 10
fastNLP/core/dataset/dataset.py View File

@@ -770,17 +770,8 @@ class DataSet:
df = self.to_pandas()
return df.to_csv(path, encoding="utf-8")


def set_ignore(self, *field_names) -> None:
"""
field_names set as inputs are fed into the AutoCollator; fields not set are filtered out by default

:param field_names:
:return:
"""
self.collator.set_ignore(*field_names)

@property
def collator(self):
def collator(self) -> Collator:
if self._collator is None:
self._collator = Collator()
return self._collator

+ 2
- 2
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -22,7 +22,7 @@ from fastNLP.core.utils import (
rank_zero_rm
)
from fastNLP.core.samplers import (
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
ReproducibleBatchSampler,
RandomSampler,
@@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):


return self.model, model.forward


def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
reproducible: bool = False):
r"""
From the input dataloader, derive a dataloader that supports distributed (distributed) and reproducible (reproducible) training.


+ 3
- 3
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -22,7 +22,7 @@ from fastNLP.core.log import logger
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
ReproducibleSampler,
RandomBatchSampler,
ReproduceBatchSampler,
RandomSampler,
)


@@ -345,7 +345,7 @@ class PaddleDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -476,7 +476,7 @@ class PaddleDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler


+ 2
- 2
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -14,7 +14,7 @@ from fastNLP.core.utils import (
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler,
RandomSampler,
re_instantiate_sampler,
@@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver):
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 2
- 2
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -15,7 +15,7 @@ from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger


@@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver):
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler)
else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 3
- 3
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler




class TorchDriver(Driver):
@@ -293,7 +293,7 @@ class TorchDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.")
else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last
@@ -407,7 +407,7 @@ class TorchDriver(Driver):
res.shuffle = True
else:
res.shuffle = False
# the RandomBatchSampler case
# the ReproduceBatchSampler case
elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler


+ 25
- 0
fastNLP/core/log/print.py View File

@@ -0,0 +1,25 @@
__all__ = [
'print'
]

from .logger import logger


def print(*args, sep=' ', end='\n', file=None, flush=False):
"""
A function that redirects the print function to logger.info.

Example:
from fastNLP import print

print("This is a test") # equivalent to calling logger.info("This is a test")

:param args: the content to print
:param sep: the separator used when multiple inputs are given.
:param end: this parameter is meaningless in the current setup, because a \n is always appended at the end.
:param file: this parameter is meaningless.
:param flush: this parameter is meaningless.
:return:
"""
line = sep.join(map(str, args))  # coerce non-str arguments, matching the builtin print
logger.info(line)
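A quick sanity check of the redirect, assuming the top-level export shown in the docstring::

    from fastNLP import print  # shadows the builtin in this module only

    print("epoch:", "1", "loss:", "0.123")  # emitted through logger.info as one line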

+ 3
- 2
fastNLP/core/samplers/__init__.py View File

@@ -14,9 +14,10 @@ __all__ = [
"UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler",


"RandomBatchSampler",
"ReproduceBatchSampler",
"BucketedBatchSampler",
"ReproducibleBatchSampler",
"RandomBatchSampler",


"re_instantiate_sampler"
]
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler
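To keep the rename straight: the wrapper formerly exported as RandomBatchSampler now ships as ReproduceBatchSampler, while the RandomBatchSampler name is reused for the brand-new shuffling batch sampler defined below. A small hedged sketch::

    from fastNLP.core.samplers import RandomBatchSampler, ReproduceBatchSampler

    # The new RandomBatchSampler shuffles and batches any container with __len__.
    sampler = RandomBatchSampler(dataset=range(10), batch_size=4, shuffle=True, seed=0)
    print(list(sampler))  # e.g. [[8, 2, 5, 0], [9, 4, 7, 1], [6, 3]]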



+ 209
- 5
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -1,5 +1,6 @@
__all__ = [
'BucketedBatchSampler',
"ReproduceBatchSampler",
"RandomBatchSampler"
]


@@ -7,7 +8,6 @@ import math
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os


import numpy as np


@@ -54,13 +54,12 @@ class ReproducibleBatchSampler:
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")




class RandomBatchSampler(ReproducibleBatchSampler):
# the values of these two parameters should be fetched via the driver's get_dataloader_args function;
class ReproduceBatchSampler(ReproducibleBatchSampler):
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
A wrapper that makes the state of a batch_sampler object restorable.


:param batch_sampler: an iterable that yields numbers or lists of numbers. RandomBatchSampler first iterates over it once, then
:param batch_sampler: an iterable that yields numbers or lists of numbers. ReproduceBatchSampler first iterates over it once, then
caches the yielded indices, emitting lists of indices in batches of size batch_size when used.
:param batch_size: the size of each batch.
:param drop_last: whether to drop the last batch if it cannot gather batch_size samples.
@@ -143,7 +142,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.need_reinitialize = False


def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")


def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -158,6 +157,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size




class RandomBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
drop_last: bool = False, seed: int = 0, **kwargs):
"""
A batch_sampler that forms batches at random.

:param dataset: a data container implementing the __len__ method.
:param batch_size: the size of each batch
:param shuffle: whether to shuffle the order of the data; if False, the data is output sequentially.
:param drop_last: whether to drop the last batch if it cannot gather batch_size samples.
:param seed: the random seed to use
:param kwargs: reserved for fastNLP internal use
"""
super().__init__()

self.dataset = dataset

self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.seed = seed

self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)  # total number of samples iterated so far, including those produced on other cards in the multi-card case

# parameters related to multi-card training
self.num_replicas = kwargs.get("num_replicas", 1)
self.rank = kwargs.get("rank", 0)
self.epoch = kwargs.get("epoch", -1)
self.pad = kwargs.get("pad", False)  # this parameter has no meaning on a single card;

# whether we are between iterations; when True, set_distributed() and load_state_dict() must not be called
self.during_iter = kwargs.get("during_iter", False)

# the variables below are internal state used for restoring progress.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)

def set_distributed(self, num_replicas, rank, pad=True):
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
"during an unfinished iteration."
assert num_replicas > 0 and isinstance(num_replicas, int)
assert isinstance(rank, int) and 0 <= rank < num_replicas
# note that when this is called, all state should default to that of a freshly started epoch;
self.num_replicas = num_replicas
self.rank = rank
self.pad = pad

return self

def __iter__(self):
if self.during_iter:  # if during_iter is already True, the previous iteration never finished; force a re-initialization
self.num_consumed_samples = 0
self.during_iter = True

indices = list(range(len(self.dataset)))

if self.shuffle:
if self.num_consumed_samples > 0:  # first rebuild the original ordering, then drop what was already consumed
_batches = []
for _i in range(self.old_num_replicas):
_indices = indices[_i:len(indices):self.old_num_replicas]
__batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
_batches.append(__batches)
batches = list(chain(*[_ for _ in zip(*_batches)]))
indices = list(chain(*batches))
indices = indices[self.num_consumed_samples:]
# take out this rank's slice,
indices = indices[self.rank:len(indices):self.num_replicas]
batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
batches = list(map(list, batches))
else:
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
_num_batches = len(indices) // self.batch_size
if _num_batches == 0:
batches = [indices]
else:
batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
if len(indices)%self.batch_size!=0:
batches.append(indices[_num_batches*self.batch_size:])

need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
batches[-1].append(batches[-1][0])  # this guarantees the length of this bucket is not broken.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1])==0:
batches.pop(-1)

assert sum(map(len, batches)) == self.num_left_samples

if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]

for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
self.old_batch_size = self.batch_size
self.old_num_replicas = self.num_replicas
if self.epoch < 0:  # guard against the user never setting epoch, which would make every epoch identical
self.epoch -= 1

def batchify(self, indices, batch_size, seed):
"""
Split indices into batches

:param indices: List[int]
:param batch_size: int
:param seed: int
:return: List[List[int]]
"""
# deterministic shuffle driven by the given seed
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
num_samples = 0
batches = []
while num_samples<len(indices):
batches.append(indices[num_samples:num_samples+batch_size])
num_samples += batch_size
return batches

def set_epoch(self, epoch):
self.epoch = epoch

@property
def batch_idx_in_epoch(self):
if self.drop_last:
return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
else:
return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
(self.num_left_samples + self.batch_size - 1) // self.batch_size

@property
def total_size(self):
"""
The number of indices this sampler will ultimately produce (including those for other ranks); because of replicas and
padding, this value may be equal to, greater than, or smaller than len(dataset)

:return:
"""
return self.num_consumed_samples + self.num_replicas*self.num_left_samples

@property
def num_left_samples(self):
"""
Return how many samples remain before the current iteration finishes; this counts only the current rank's share.

:return:
"""
num_consumed_samples = self.num_consumed_samples
return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

def __len__(self)->int:
"""
Return how many more batches of data this sampler will yield

:return:
"""
num_sampler_per_rank = self.total_size//self.num_replicas
num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
(num_sampler_per_rank+self.batch_size-1)//self.batch_size
return num_batches

def state_dict(self) -> Dict:
if self.old_batch_size != self.batch_size:
raise RuntimeError("RandomBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_replicas': self.num_replicas}

return states

def load_state_dict(self, states: Dict):
# if self.during_iter is True, then num_consumed_samples must be 0;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
f"we cannot use {self.__class__.__name__} to load it."

length = states['length']
assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
"and current dataset."
self.seed = states['seed']
self.epoch = states['epoch']
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length:  # if the final sample had already been reached when saving, simply reset to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
f"we use shuffle={states['shuffle']}")
self.shuffle = states["shuffle"]
self.old_batch_size = states['batch_size']
self.old_num_replicas = states['num_replicas']


class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):


+ 3
- 2
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet


class ReproducibleSampler:
"""
A reproducible Sampler object.

Note that the `__init__` of every class inheriting `ReproducibleSampler` must accept `**kwargs`, so that the sampler
or batch_sampler can be re-instantiated when resuming from a checkpoint; also note that no variable initialized in __init__ may begin with an underscore, while every variable not set in __init__ must begin with one.


@@ -54,13 +56,12 @@ class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
"""



:param dataset: a data container implementing the __len__ method
:param shuffle: whether to shuffle the order on every iteration.
:param seed: the random seed.
:param kwargs: not needed by users; used internally by fastNLP
"""
super(RandomSampler, self).__init__()
self.dataset = dataset
self.shuffle = shuffle
self.seed = seed


+ 2
- 2
fastNLP/core/utils/__init__.py View File

@@ -21,7 +21,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
from ..dataloaders.utils import indice_collate_wrapper





+ 2
- 2
fastNLP/core/utils/dummy_class.py View File

@@ -1,5 +1,5 @@
import functools


class DummyClass:
def __call__(self, *args, **kwargs):
return
def __init__(self, *args, **kwargs):
pass

+ 5
- 0
fastNLP/core/utils/paddle_utils.py View File

@@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]):
else:
return data.cuda(get_paddle_device_id(device))



def get_paddle_gpu_str(device: Union[str, int]):
"""
Get a device name of the form `gpu:x`
@@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]):
return device.replace("cuda", "gpu")
return f"gpu:{device}"



def get_paddle_device_id(device: Union[str, int]):
"""
Get the device id of the gpu
@@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,


return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)



def is_in_paddle_dist():
"""
Determine whether we are in a distributed process, judged via global_rank and selected_gpus
"""
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)



def is_in_fnlp_paddle_dist():
"""
Determine whether we are in a distributed process launched by FastNLP
"""
return FASTNLP_DISTRIBUTED_CHECK in os.environ



def is_in_paddle_launch_dist():
"""
Determine whether we are in a distributed process started via launch


+ 1
- 20
fastNLP/core/utils/utils.py View File

@@ -6,7 +6,7 @@ import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
from typing import Tuple, Optional
from time import sleep


@@ -35,7 +35,6 @@ __all__ = [
'nullcontext',
'pretty_table_printer',
'Option',
'indice_collate_wrapper',
'deprecated',
'seq_len_to_mask',
'rank_zero_rm',
@@ -513,24 +512,6 @@ class Option(dict):
self.update(state)




def indice_collate_wrapper(func):
"""
Wraps the collate_fn one level deeper: it separates the tuple data fetched from the dataset and packs the idx values into indices.

:param func: the function to wrap
:return:
"""

def wrapper(tuple_data):
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)

return wrapper


_emitted_deprecation_warnings = set()






+ 35
- 4
fastNLP/io/data_bundle.py View File

@@ -332,13 +332,44 @@ class DataBundle:
show_progress_bar=show_progress_bar, progress_desc=progress_desc)
return res


def set_pad_val(self, *field_names, val=0) -> None:
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle":
"""
If you need to make special adjustments to the content of some field, use this function.

:param field_name: the name of the field to adjust. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can express multi-level keys, e.g. ('a', 'b') for
{'a': {'b': 1}}; if __getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element. An error is
raised if the field is not found in the data; if __getitem__ returns the content as a whole, use "_single".
:param pad_val: the default pad value of this field. If set to None, the field will not be padded; fastNLP only
pads fields that are in a paddable form, so if the field itself is not paddable, there is no need to set this to
None explicitly. If backend is None, this value is meaningless.
:param dtype: for a field that needs padding, the dtype its data should have.
:param backend: one of ['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'], producing output of type list,
numpy.ndarray, torch.Tensor, paddle.Tensor, or jittor.Var respectively. If pad_val is None, this value is meaningless.
:param pad_fn: a custom pad function for this field; if passed, the pad_val, dtype and backend parameters are ignored.
The input of pad_fn is the batch form of the current field: the Collator automatically unbatches the data and regroups
each field into its own batch; pad_fn receives the field's batch form, and its output is used directly as the result.
:return: self
"""
for _, ds in self.iter_datasets():
ds.set_pad_val(*field_names, val=val)
ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend,
pad_fn=pad_fn)
return self


def set_input(self, *field_names) -> None:
def set_ignore(self, *field_names) -> "DataBundle":
"""
If some content should not be output, it can be set here; the fields so marked will be ignored in the batch output.
Ex::
collator.set_ignore('field1', 'field2')

:param field_names: the names of the fields to ignore. If the Dataset's __getitem__ returns a dict, the corresponding
field key can be used directly; for a nested dict, a tuple can be used, e.g. ('a', 'b') for {'a': {'b': 1}}; if
__getitem__ returns a Sequence, '_0', '_1' can denote the 0th or 1st element of the sequence.
:return: self
"""
for _, ds in self.iter_datasets():
ds.set_input(*field_names)
ds.collator.set_ignore(*field_names)
return self
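Since both setters fan out to every dataset's collator and return self, DataBundle calls now chain; a hedged sketch (the constructor usage is assumed, not re-verified)::

    from fastNLP import DataSet
    from fastNLP.io.data_bundle import DataBundle

    bundle = DataBundle(datasets={
        "train": DataSet({"words": [[1, 2], [3]], "raw": ["a b", "c"]}),
        "dev": DataSet({"words": [[4, 5, 6]], "raw": ["d e f"]}),
    })
    bundle.set_pad("words", pad_val=0).set_ignore("raw")  # applied to train and dev alike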


def __repr__(self) -> str:
_str = ''


+ 208
- 0
tests/core/callbacks/test_callback_event.py View File

@@ -0,0 +1,208 @@
import pytest
from functools import reduce

from fastNLP.core.callbacks.callback_event import Event, Filter



class TestFilter:
def test_every_filter(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w-1 for w in range(10, 101, 10)]

# every = 1
@Filter(every=1)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))

def test_once_filter(self):
# once = 10
@Filter(once=10)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]


def test_extract_filter_from_fn(self):
@Filter(every=10)
def _fn(data):
return data

_filter_num_called = []
_filter_num_executed = []
for i in range(100):
cu_res = _fn(i)
_filter = _fn.__fastNLP_filter__
_filter_num_called.append(_filter.num_called)
_filter_num_executed.append(_filter.num_executed)
assert _filter_num_called == list(range(1, 101))
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10]

def _fn(data):
return data
assert not hasattr(_fn, "__fastNLP_filter__")

def test_filter_state_dict(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data

_res = []
for i in range(50):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 51, 10)]

# save the state
state = _fn.__fastNLP_filter__.state_dict()
# restore the state
_fn.__fastNLP_filter__.load_state_dict(state)

_res = []
for i in range(50, 100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(60, 101, 10)]


@pytest.mark.torch
def test_filter_fn_torch():
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)

trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False

@Filter(filter_fn=filter_fn)
def _fn(trainer, data):
return data

_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]


class TestCallbackEvents:
def test_every(self):

# which event is used does not matter here, since we test it decoupled from the Trainer;
event_state = Event.on_train_begin() # passing no arguments should default to every=1;
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))

event_state = Event.on_train_begin(every=10)
@Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 101, 10)]

def test_once(self):
event_state = Event.on_train_begin(once=10)

@Filter(once=event_state.once)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]


@pytest.mark.torch
def test_callback_events_torch():
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)

trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False

event_state = Event.on_train_begin(filter_fn=filter_fn)

@Filter(filter_fn=event_state.filter_fn)
def _fn(trainer, data):
return data

_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]










+ 0
- 157
tests/core/callbacks/test_callback_events.py View File

@@ -1,157 +0,0 @@
import pytest
from functools import reduce

from fastNLP.core.callbacks.callback_events import Events, Filter


class TestFilter:

def test_params_check(self):
# passes without error
_filter1 = Filter(every=10)
_filter2 = Filter(once=10)
_filter3 = Filter(filter_fn=lambda: None)

# triggers ValueError
with pytest.raises(ValueError) as e:
_filter4 = Filter()
exec_msg = e.value.args[0]
assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter."

# triggers ValueError
with pytest.raises(ValueError) as e:
_filter5 = Filter(every=10, once=10)
exec_msg = e.value.args[0]
assert exec_msg == "These three values should be only set one."

# triggers ValueError
with pytest.raises(ValueError) as e:
_filter6 = Filter(every="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument every should be integer and greater than zero"

# triggers ValueError
with pytest.raises(ValueError) as e:
_filter7 = Filter(once="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument once should be integer and positive"

# triggers TypeError
with pytest.raises(TypeError) as e:
_filter7 = Filter(filter_fn="heihei")
exec_msg = e.value.args[0]
assert exec_msg == "Argument event_filter should be a callable"

def test_every_filter(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w-1 for w in range(10, 101, 10)]

# every = 1
@Filter(every=1)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == list(range(100))

def test_once_filter(self):
# once = 10
@Filter(once=10)
def _fn(data):
return data

_res = []
for i in range(100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [9]

def test_filter_fn(self):
from torch.optim import SGD
from torch.utils.data import DataLoader
from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification

model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
optimizer = SGD(model.parameters(), lr=0.0001)
dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
dataloader = DataLoader(dataset=dataset, batch_size=4)

trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
def filter_fn(filter, trainer):
if trainer.__heihei_test__ == 10:
return True
return False

@Filter(filter_fn=filter_fn)
def _fn(trainer, data):
return data

_res = []
for i in range(100):
trainer.__heihei_test__ = i
cu_res = _fn(trainer, i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [10]

def test_extract_filter_from_fn(self):
@Filter(every=10)
def _fn(data):
return data

_filter_num_called = []
_filter_num_executed = []
for i in range(100):
cu_res = _fn(i)
_filter = _fn.__fastNLP_filter__
_filter_num_called.append(_filter.num_called)
_filter_num_executed.append(_filter.num_executed)
assert _filter_num_called == list(range(1, 101))
assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10]

def _fn(data):
return data
assert not hasattr(_fn, "__fastNLP_filter__")

def test_filter_state_dict(self):
# every = 10
@Filter(every=10)
def _fn(data):
return data

_res = []
for i in range(50):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(10, 51, 10)]

# save the state
state = _fn.__fastNLP_filter__.state_dict()
# load the state back
_fn.__fastNLP_filter__.load_state_dict(state)

_res = []
for i in range(50, 100):
cu_res = _fn(i)
if cu_res is not None:
_res.append(cu_res)
assert _res == [w - 1 for w in range(60, 101, 10)]
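
The state-dict round trip this removed test covered is worth keeping in mind when checkpointing: the filter's counters survive save/load, so a restored run keeps its cadence. A minimal sketch, assuming the decorator still attaches itself as __fastNLP_filter__ in the new callback_event module:

from fastNLP.core.callbacks.callback_event import Filter

@Filter(every=10)
def _fn(step):
    return step

for i in range(50):                        # calls 1..50 fire at 10, 20, 30, 40, 50
    _fn(i)

state = _fn.__fastNLP_filter__.state_dict()
_fn.__fastNLP_filter__.load_state_dict(state)

fired = [r for r in (_fn(i) for i in range(50, 100)) if r is not None]
assert fired == [59, 69, 79, 89, 99]       # the cadence continues where it left off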



+ 13 - 9 tests/core/callbacks/test_checkpoint_callback_torch.py

@@ -2,9 +2,6 @@ import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
from pathlib import Path
import re
import time
@@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy
from fastNLP.core.log import logger

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

@dataclass
class ArgMaxDatasetConfig:
@@ -216,9 +218,9 @@ def test_model_checkpoint_callback_2(
path = Path.cwd().joinpath("test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event

@Trainer.on(Events.on_train_epoch_end)
@Trainer.on(Event.on_train_epoch_end())
def raise_exception(trainer):
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
raise NotImplementedError
@@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2(

if version == 0:
callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc",
folder=path,
every_n_epochs=None,
@@ -558,12 +560,13 @@
topk=None,
last=False,
on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
)
]
elif version == 1:
callbacks = [
TrainerCheckpointCallback(
CheckpointCallback(
monitor="acc",
folder=path,
every_n_epochs=None,
@@ -571,7 +574,8 @@
topk=1,
last=True,
on_exception=None,
model_save_fn=model_save_fn
model_save_fn=model_save_fn,
save_object="trainer"
)
]
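
Besides the guarded torch imports, the substantive change here is that the old TrainerCheckpointCallback is now expressed through CheckpointCallback plus a save_object argument. A hedged sketch of the new spelling (import path and argument meanings assumed from the test above):

from pathlib import Path
from fastNLP.core.callbacks import CheckpointCallback  # assumed import path

callback = CheckpointCallback(
    monitor="acc",              # metric used to rank the topk checkpoints
    folder=Path("checkpoints"), # where checkpoint folders are created
    every_n_epochs=None,
    topk=1,
    last=True,                  # additionally keep the most recent checkpoint
    on_exception=None,
    save_object="trainer",      # "trainer" saves trainer state, which the old class implied
)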




+ 6 - 4 tests/core/callbacks/test_more_evaluate_callback.py

@@ -12,9 +12,7 @@ import os
import pytest
from typing import Any
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist

from pathlib import Path
import re

@@ -29,7 +27,11 @@ from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist


@dataclass
class ArgMaxDatasetConfig:
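
The same import-guard pattern recurs across this commit: torch-dependent imports move behind fastNLP's _NEED_IMPORT_TORCH flag, so merely collecting these test modules no longer requires torch to be installed, while the torch-marked tests are selected or skipped by marker. The pattern, as it appears in these files:

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch.distributed as dist
    from torch.optim import SGD
    from torch.utils.data import DataLoader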


+ 24 - 1 tests/core/collators/padders/test_get_padder.py

@@ -17,12 +17,13 @@ def test_get_element_shape_dtype():
@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
@pytest.mark.torch
@pytest.mark.paddle
@pytest.mark.jittor
def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch")
if not _NEED_IMPORT_PADDLE and backend == 'paddle':
pytest.skip("No paddle")
if not _NEED_IMPORT_PADDLE and backend == 'jittor':
if not _NEED_IMPORT_JITTOR and backend == 'jittor':
pytest.skip("No jittor")
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
@@ -66,6 +67,13 @@ def test_raw_padder():
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3, 2)

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, list)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12


def test_numpy_padder():
backend = 'numpy'
@@ -140,3 +148,18 @@ def test_torch_padder():
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

# can also be a numpy.ndarray
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12

# test padding torch tensors to numpy
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, np.ndarray)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12
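
The new assertions give get_padder cross-backend coverage: numpy arrays padded by the torch backend, and torch tensors padded into a numpy batch. A self-contained restatement of the "to numpy" case, using only calls that appear in the diff (import path assumed from this commit's fastNLP/core/collators/padders/get_padder.py):

import numpy as np
import torch
from fastNLP.core.collators.padders.get_padder import get_padder  # assumed path

batch_field = [torch.ones((3, 3)), torch.ones((2, 3)), torch.ones((1, 0))]
padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test')
pad_batch = padder(batch_field)

assert isinstance(pad_batch, np.ndarray)
assert pad_batch.shape == (3, 3, 3)   # padded to the max extent of each dimension
assert (pad_batch == 0).sum() == 12   # 27 slots, 15 ones copied in, 12 padded zeros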

+ 18 - 18 tests/core/collators/padders/test_paddle_padder.py

@@ -1,7 +1,7 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder
from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

@@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE:


@pytest.mark.paddle
class TestpaddleNumberPadder:
class TestPaddleNumberPadder:
def test_run(self):
padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, paddle.Tensor)
@@ -20,9 +20,9 @@ class TestpaddleNumberPadder:


@pytest.mark.paddle
class TestpaddleSequencePadder:
class TestPaddleSequencePadder:
def test_run(self):
padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
@@ -32,20 +32,20 @@ class TestpaddleSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]])
# assert (a>67).sum()==0  # because of the limited int8 value range
padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)


@pytest.mark.paddle
class TestpaddleTensorPadder:
class TestPaddleTensorPadder:
def test_run(self):
padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1)
a = [paddle.zeros((3,)), paddle.zeros((2,))]
a = padder(a)
shape = a.shape
@@ -74,7 +74,7 @@ class TestpaddleTensorPadder:
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1)
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))]
a = padder(a)
shape = a.shape
@@ -85,7 +85,7 @@ class TestpaddleTensorPadder:
])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)]
a = padder(a)
shape = a.shape
@@ -96,11 +96,11 @@ class TestpaddleTensorPadder:
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1)
padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1)

def test_v1(self):
print(paddle.zeros((3, )).dtype)
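
This whole file is a mechanical rename, paddleXxxPadder to PaddleXxxPadder (CapWords); behaviour is untouched. A hedged usage sketch mirroring TestPaddleSequencePadder, assuming paddle is installed:

import paddle
from fastNLP.core.collators.padders.paddle_padder import PaddleSequencePadder

padder = PaddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
batch = padder([[1, 2, 3], [3]])        # ragged python lists in
assert isinstance(batch, paddle.Tensor)
assert tuple(batch.shape) == (2, 3)     # the short row is padded with -1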

+ 1 - 2 tests/core/collators/padders/test_raw_padder.py

@@ -23,7 +23,6 @@ class TestRawSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int)

+ 287 - 75 tests/core/collators/test_collator.py

@@ -1,81 +1,293 @@

import numpy as np
import pytest


from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.collator import Collator


def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()


def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))


def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)




class TestCollator:


@pytest.mark.parametrize('as_numpy', [True, False])
def test_auto_collator(self, as_numpy):
"""
Test the auto_pad behaviour of auto_collator

:param as_numpy:
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=as_numpy)
collator.set_input('x', 'y')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
results = []
for bucket in bucket_data:
res = collator(bucket)
assert res['x'].shape == (40, 5)
assert res['y'].shape == (40,)
results.append(res)

def test_auto_collator_v1(self):
"""
Test the set_pad_val and set_as_numpy behaviour of auto_collator

:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
collator.set_input('x')
collator.set_pad_val('x', val=-1)
collator.set_as_numpy(True)
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = collator(bucket)
print(res)

def test_multicollator(self):
"""
测试multicollator功能

:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
multi_collator = _MultiCollator(collator)
multi_collator.set_as_numpy(as_numpy=True)
multi_collator.set_pad_val('x', val=-1)
multi_collator.set_input('x')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = multi_collator(bucket)
print(res)
@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))

if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))

def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}

# test set_ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# test set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

# test setting the pad value
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))


# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))

def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# test set_ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)
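
The new Collator tests above reduce to a small configuration surface: set_ignore drops fields (tuples address nested dict keys), set_pad overrides the pad value, dtype, or backend per field, and pad_fn takes over padding entirely, as in test_nest_ignore. A condensed sketch of the first two, restating the assertions above:

import numpy as np
from fastNLP.core.collators.collator import Collator

batch = [{'x': [1], 'id': 1}, {'x': [1, 2], 'id': 2}]

collator = Collator(backend='numpy')
collator.set_ignore('id')             # 'id' disappears from the output batch
collator.set_pad('x', pad_val=-1)     # ragged 'x' is padded with -1 instead of 0
out = collator(batch)

assert 'id' not in out
assert (out['x'] == np.array([[1, -1], [1, 2]])).all()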







+ 0 - 293 tests/core/collators/test_new_collator.py

@@ -1,293 +0,0 @@

import numpy as np
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.new_collator import Collator


def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()


def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))


def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)


class TestCollator:

@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))

if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))

def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}

# test set_ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# test set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

# test setting the pad value
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# set backend and dtype
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))


# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))

def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# test set_ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)







+ 219 - 10 tests/core/controllers/test_trainer_event_trigger.py

@@ -1,17 +1,20 @@
import pytest
from typing import Any
from dataclasses import dataclass
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist


from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
import torch.distributed as dist


@dataclass
@@ -62,12 +65,11 @@ def model_and_optimizers():

return trainer_params

@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")])  # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger(
def test_trainer_event_trigger_1(
model_and_optimizers: TrainerParameters,
driver,
device,
@@ -97,8 +99,215 @@
if dist.is_initialized():
dist.destroy_process_group()


for name, member in Events.__members__.items():
assert member.value in output[0]
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]

@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@magic_argv_env_context
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):

@Trainer.on(Event.on_after_trainer_initialized())
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")

@Trainer.on(Event.on_sanity_check_begin())
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")

@Trainer.on(Event.on_sanity_check_end())
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")

@Trainer.on(Event.on_train_begin())
def on_train_begin(trainer):
print("on_train_begin")

@Trainer.on(Event.on_train_end())
def on_train_end(trainer):
print("on_train_end")

@Trainer.on(Event.on_train_epoch_begin())
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
# trigger on_exception;
raise Exception
print("on_train_epoch_begin")

@Trainer.on(Event.on_train_epoch_end())
def on_train_epoch_end(trainer):
print("on_train_epoch_end")

@Trainer.on(Event.on_fetch_data_begin())
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")

@Trainer.on(Event.on_fetch_data_end())
def on_fetch_data_end(trainer):
print("on_fetch_data_end")

@Trainer.on(Event.on_train_batch_begin())
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")

@Trainer.on(Event.on_train_batch_end())
def on_train_batch_end(trainer):
print("on_train_batch_end")

@Trainer.on(Event.on_exception())
def on_exception(trainer, exception):
print("on_exception")

@Trainer.on(Event.on_before_backward())
def on_before_backward(trainer, outputs):
print("on_before_backward")

@Trainer.on(Event.on_after_backward())
def on_after_backward(trainer):
print("on_after_backward")

@Trainer.on(Event.on_before_optimizers_step())
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")

@Trainer.on(Event.on_after_optimizers_step())
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")

@Trainer.on(Event.on_before_zero_grad())
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")

@Trainer.on(Event.on_after_zero_grad())
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")

@Trainer.on(Event.on_evaluate_begin())
def on_evaluate_begin(trainer):
print("on_evaluate_begin")

@Trainer.on(Event.on_evaluate_end())
def on_evaluate_end(trainer, results):
print("on_evaluate_end")

with pytest.raises(Exception):
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()

Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_3(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):
import re

once_message_1 = "This message should be typed 1 times."
once_message_2 = "test_filter_fn"
once_message_3 = "once message 3"
twice_message = "twice message hei hei"

@Trainer.on(Event.on_train_epoch_begin(every=2))
def train_epoch_begin_1(trainer):
print(once_message_1)

@Trainer.on(Event.on_train_epoch_begin())
def train_epoch_begin_2(trainer):
print(twice_message)

@Trainer.on(Event.on_train_epoch_begin(once=2))
def train_epoch_begin_3(trainer):
print(once_message_3)

def filter_fn(filter, trainer):
if trainer.cur_epoch_idx == 1:
return True
else:
return False

@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn))
def test_filter_fn(trainer):
print(once_message_2)

with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()


once_pattern_1 = re.compile(once_message_1)
once_pattern_2 = re.compile(once_message_2)
once_pattern_3 = re.compile(once_message_3)
twice_pattern = re.compile(twice_message)

once_res_1 = once_pattern_1.findall(output[0])
assert len(once_res_1) == 1
once_res_2 = once_pattern_2.findall(output[0])
assert len(once_res_2) == 1
once_res_3 = once_pattern_3.findall(output[0])
assert len(once_res_3) == 1
twice_res = twice_pattern.findall(output[0])
assert len(twice_res) == 2
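
For reference, the three registration styles this test distinguishes, side by side (sketch only; Trainer and Event as imported at the top of this file):

@Trainer.on(Event.on_train_epoch_begin(every=2))      # fires on epochs 2, 4, 6, ...
def every_second_epoch(trainer):
    ...

@Trainer.on(Event.on_train_epoch_begin(once=2))       # fires on the 2nd epoch only
def second_epoch_only(trainer):
    ...

@Trainer.on(Event.on_train_epoch_end(filter_fn=lambda filter, trainer: trainer.cur_epoch_idx == 1))
def custom_predicate(trainer):                        # arbitrary filter_fn(filter, trainer)
    ...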


+ 5 - 5 tests/core/controllers/test_trainer_other_things.py

@@ -1,22 +1,22 @@
import pytest

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks import Events
from fastNLP.core.callbacks import Event
from tests.helpers.utils import magic_argv_env_context


@magic_argv_env_context
def test_trainer_torch_without_evaluator():
@Trainer.on(Events.on_train_epoch_begin(every=10))
@Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things")
def fn1(trainer):
pass

@Trainer.on(Events.on_train_batch_begin(every=10))
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn2(trainer, batch, indices):
pass

with pytest.raises(AssertionError):
@Trainer.on(Events.on_train_batch_begin(every=10))
with pytest.raises(BaseException):
@Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things")
def fn3(trainer, batch):
pass
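
A hedged note on the two changes here: Trainer.on now accepts a marker, and the signature check on fn3 is only asserted to raise some BaseException rather than specifically AssertionError. Hypothetical pairing of the marker between handler and trainer (sketch, not in the diff; the binding semantics are an assumption based on this test's per-file marker):

@Trainer.on(Event.on_train_epoch_begin(every=10), marker="my_run")
def fn1(trainer):
    pass

trainer = Trainer(model=model, driver="torch", device="cpu",
                  train_dataloader=dataloader, optimizers=optimizer,
                  marker="my_run")   # assumed: picks up handlers registered under "my_run"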




+ 6 - 4 tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -2,9 +2,7 @@
Note: the test functions in this file should all be modified versions of test functions already covered in `test_trainer_w_evaluator_torch.py`, with metrics and an evaluator added;
"""
import pytest
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist

from dataclasses import dataclass
from typing import Any
from torchmetrics import Accuracy
@@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
from tests.helpers.utils import magic_argv_env_context

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.distributed as dist


@dataclass
class NormalClassificationTrainTorchConfig:


+ 9 - 5 tests/core/controllers/test_trainer_wo_evaluator_torch.py

@@ -2,9 +2,7 @@ import os.path
import subprocess
import sys
import pytest
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader

from dataclasses import dataclass
from typing import Any
from pathlib import Path
@@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
from tests.helpers.utils import magic_argv_env_context, Capturing
from fastNLP.core import rank_zero_rm
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.optim import SGD
from torch.utils.data import DataLoader


@dataclass
@@ -257,9 +260,9 @@ def test_trainer_on_exception(
cur_rank,
n_epochs=2,
):
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event

@Trainer.on(Events.on_train_epoch_end)
@Trainer.on(Event.on_train_epoch_end())
def raise_exception(trainer):
if trainer.driver.get_local_rank() == cur_rank:
raise NotImplementedError
@@ -286,6 +289,7 @@ def test_trainer_on_exception(
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("version", [0, 1, 2, 3])
@magic_argv_env_context
def test_torch_distributed_launch_1(version):


+ 5 - 7 tests/core/controllers/utils/test_utils.py

@@ -1,7 +1,7 @@
from functools import reduce

from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: this class was modified, remember to update the test as well;
from tests.helpers.datasets.normal_data import NormalIterator
from tests.helpers.datasets.normal_data import NormalSampler


class Test_WrapDataLoader:
@@ -9,9 +9,9 @@ class Test_WrapDataLoader:
def test_normal_generator(self):
all_sanity_batches = [4, 20, 100]
for sanity_batches in all_sanity_batches:
data = NormalIterator(num_of_data=1000)
data = NormalSampler(num_of_data=1000)
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
dataloader = iter(wrapper(dataloader=data))
dataloader = iter(wrapper)
mark = 0
while True:
try:
@@ -32,8 +32,7 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
dataloader = iter(dataloader)
dataloader = iter(wrapper)
all_supposed_running_data_num = 0
while True:
try:
@@ -55,6 +54,5 @@ class Test_WrapDataLoader:
dataset = TorchNormalDataset(num_of_data=1000)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches)
dataloader = wrapper(dataloader)
length.append(len(dataloader))
length.append(len(wrapper))
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))])
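
The test now treats _TruncatedDataLoader as the dataloader itself: iterate it and take len() of it, instead of calling it with the original dataloader again. A minimal sketch, assuming the wrapper accepts any sized iterable as the NormalSampler usage above suggests:

from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader

data = range(1000)                                   # stand-in for a real dataloader
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=4)

assert len(wrapper) == 4                             # len() now comes from the wrapper
assert len(list(iter(wrapper))) == 4                 # iter(wrapper), not wrapper(data)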

+ 2 - 1 tests/core/drivers/jittor_driver/test_single_device.py

@@ -15,7 +15,7 @@ else:






class Model (Module):
class Model(Module):
def __init__ (self):
super (Model, self).__init__()
self.conv1 = nn.Conv (3, 32, 3, 1) # no padding
@@ -45,6 +45,7 @@ class Model (Module):
return x

@pytest.mark.jittor
@pytest.mark.skip("Skip jittor tests now.")
class TestSingleDevice:

def test_on_gpu_without_fp16(self):


+ 15 - 15 tests/core/drivers/paddle_driver/test_single_device.py

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path

from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -278,7 +278,7 @@ class TestPaddleDriverFunctions:
dataset = PaddleNormalDataset()
dataloader = DataLoader(
dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size,
drop_last,
@@ -287,7 +287,7 @@ class TestPaddleDriverFunctions:
res = PaddleSingleDriver.get_dataloader_args(dataloader)

assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler)
else:
@@ -387,7 +387,7 @@ class TestSetDistReproDataloader:
"""
Test the behaviour of set_dist_repro_dataloader when the `reproducible` argument is True.
When dist is a string, a new dataloader should be returned; if the original sampler is paddle.io.RandomSampler(shuffle=True),
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with RandomBatchSampler
only the Sampler is replaced with RandomSampler; otherwise the batch_sampler is replaced with ReproduceBatchSampler
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -400,7 +400,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else:
# in this case the batch_sampler is replaced
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -414,11 +414,11 @@ class TestSetDistReproDataloader:
a new dataloader should be returned, with its batch_sampler replaced by the Sampler given as dist
"""
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist

self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -450,7 +450,7 @@ class TestSetDistReproDataloader:
"""
dataloader = DataLoader(
dataset=self.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4,
drop_last=False,
@@ -459,7 +459,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last
@@ -500,20 +500,20 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches:
break
already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict()
else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

# after reloading, the remaining data should still come out; for PaddleNormalDataset the sorted indices should form a range
left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# 重新改造 dataloader # 重新改造 dataloader
new_loader = DataLoader( new_loader = DataLoader(
dataset=replaced_loader.dataset, dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size, batch_size=batch_size,
drop_last=False, drop_last=False,
@@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
dataset = PaddleRandomMaxDataset(40, 10) dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader( dataloader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
) )
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")


@@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 更改 batch_size # 更改 batch_size
dataloader = DataLoader( dataloader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
) )
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader") replaced_loader = load_states.pop("dataloader")
@@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. 检查 batch_sampler 是否被正确地加载和替换 # 2. 检查 batch_sampler 是否被正确地加载和替换
assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
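This file's changes are a pure rename: the reproducible batch sampler used by the paddle single-device tests is now called ReproduceBatchSampler instead of RandomBatchSampler, a name that better matches its job of wrapping an ordinary batch sampler and tracking how many samples have been consumed, so that iteration can resume mid-epoch from a state_dict. A minimal sketch of that idea (illustrative only, not fastNLP's implementation, which as the assertions above show also records an index_list of the materialized order so a resumed run replays the identical sequence):

class ResumableBatchSampler:
    # Sketch: wrap a batch sampler and track progress so that a run
    # restored from state_dict() continues where the last one stopped.
    def __init__(self, batch_sampler, batch_size, drop_last):
        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_consumed_samples = 0

    def __iter__(self):
        to_skip = self.num_consumed_samples
        for batch in self.batch_sampler:
            if to_skip >= len(batch):
                to_skip -= len(batch)        # consumed before the checkpoint
                continue
            self.num_consumed_samples += len(batch)
            yield batch
        self.num_consumed_samples = 0        # epoch finished; start fresh next time

    def state_dict(self):
        return {"num_consumed_samples": self.num_consumed_samples}

    def load_state_dict(self, states):
        self.num_consumed_samples = states["num_consumed_samples"]

The sketch only resumes correctly if the wrapped sampler replays the same order across runs, which is exactly why the real class snapshots its index_list.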




+ 3
- 3
tests/core/drivers/paddle_driver/test_utils.py

@@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import (
     replace_batch_sampler,
     replace_sampler,
 )
-from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
+from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
 from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
 if _NEED_IMPORT_PADDLE:
     import paddle
@@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices,
 def test_replace_batch_sampler():
     dataset = PaddleNormalDataset(10)
     dataloader = DataLoader(dataset, batch_size=32)
-    batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
+    batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

     replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

     assert not (replaced_loader is dataloader)
-    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+    assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
     assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
     assert len(replaced_loader.dataset) == len(dataset)
     assert replaced_loader.batch_sampler.batch_size == 16
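Note that replace_batch_sampler cannot simply assign to dataloader.batch_sampler: both torch and paddle DataLoaders derive internal state from the batch_sampler at construction time, so the helper has to rebuild the loader, which is why the test asserts the returned object is a different instance. A rough sketch of the idea under that assumption (the real helper forwards the DataLoader's other constructor arguments as well; this hypothetical version forwards only what the test checks):

def replace_batch_sampler_sketch(dataloader, new_batch_sampler):
    # Re-instantiate the same DataLoader class around the same dataset,
    # swapping in the new batch_sampler.
    return type(dataloader)(
        dataset=dataloader.dataset,
        batch_sampler=new_batch_sampler,
    )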


+ 149
- 118
tests/core/drivers/torch_driver/test_ddp.py

@@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
 from tests.helpers.utils import magic_argv_env_context
 from fastNLP.core import rank_zero_rm
-
-import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader, BatchSampler
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed as dist
+    from torch.utils.data import DataLoader, BatchSampler


-def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"):
+def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"):
     torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension)
     torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
     device = [torch.device(i) for i in device]
@@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
 #
 ############################################################################

+@pytest.mark.torch
+@magic_argv_env_context
+def test_multi_drivers():
+    """
+    Tests the case where several TorchDDPDriver instances are used.
+    """
+    generate_driver(10, 10)
+    generate_driver(20, 10)
+    with pytest.raises(RuntimeError):
+        # different device settings should raise an error
+        generate_driver(20, 3, device=[0,1,2])
+        assert False
+    dist.barrier()
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
 @pytest.mark.torch
 class TestDDPDriverFunction:
     """
     Test class for a few simple TorchDDPDriver functions; mainly checks that they run and that there are no import errors
     """

-    @classmethod
-    def setup_class(cls):
-        cls.driver = generate_driver(10, 10)
-
     @magic_argv_env_context
-    def test_multi_drivers(self):
+    def test_simple_functions(self):
         """
-        Tests the case where several TorchDDPDriver instances are used.
+        Quick test of several functions
         """
-        driver2 = generate_driver(20, 10)
-        with pytest.raises(RuntimeError):
-            # different device settings should raise an error
-            driver3 = generate_driver(20, 3, device=[0,1,2])
-            assert False
-        dist.barrier()
+        driver = generate_driver(10, 10)

-    @magic_argv_env_context
-    def test_move_data_to_device(self):
         """
-        This function only calls torch_move_data_to_device; its test cases are in tests/core/utils/test_torch_utils.py,
-        so it is not tested again here
+        Tests the move_data_to_device function. It only calls torch_move_data_to_device; its test cases are in
+        tests/core/utils/test_torch_utils.py, so it is not tested again here
         """
-        self.driver.move_data_to_device(torch.rand((32, 64)))
-
+        driver.move_data_to_device(torch.rand((32, 64)))
         dist.barrier()

-    @magic_argv_env_context
-    def test_is_distributed(self):
         """
         Tests the is_distributed function
         """
-        assert self.driver.is_distributed() == True
+        assert driver.is_distributed() == True
         dist.barrier()

-    @magic_argv_env_context
-    def test_get_no_sync_context(self):
         """
         Tests the get_no_sync_context function
         """
-        res = self.driver.get_model_no_sync_context()
+        res = driver.get_model_no_sync_context()
         dist.barrier()

-    @magic_argv_env_context
-    def test_is_global_zero(self):
         """
         Tests the is_global_zero function
         """
-        self.driver.is_global_zero()
+        driver.is_global_zero()
         dist.barrier()

-    @magic_argv_env_context
-    def test_unwrap_model(self):
         """
         Tests the unwrap_model function
         """
-        self.driver.unwrap_model()
+        driver.unwrap_model()
         dist.barrier()

-    @magic_argv_env_context
-    def test_get_local_rank(self):
         """
         Tests the get_local_rank function
         """
-        self.driver.get_local_rank()
+        driver.get_local_rank()
         dist.barrier()

-    @magic_argv_env_context
-    def test_all_gather(self):
         """
         Tests the all_gather function
         The detailed tests are done in test_dist_utils.py
         """
         obj = {
-            "rank": self.driver.global_rank
+            "rank": driver.global_rank
         }
-        obj_list = self.driver.all_gather(obj, group=None)
+        obj_list = driver.all_gather(obj, group=None)
         for i, res in enumerate(obj_list):
             assert res["rank"] == i

-    @magic_argv_env_context
-    @pytest.mark.parametrize("src_rank", ([0, 1]))
-    def test_broadcast_object(self, src_rank):
         """
         Tests the broadcast_object function
         The detailed tests are done in test_dist_utils.py
         """
-        if self.driver.global_rank == src_rank:
+        if driver.global_rank == 0:
             obj = {
-                "rank": self.driver.global_rank
+                "rank": driver.global_rank
             }
         else:
             obj = None
-        res = self.driver.broadcast_object(obj, src=src_rank)
-        assert res["rank"] == src_rank
+        res = driver.broadcast_object(obj, src=0)
+        assert res["rank"] == 0
+
+        if dist.is_initialized():
+            dist.destroy_process_group()

 ############################################################################
 #
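The hunk above folds seven one-assertion tests into a single test_simple_functions and moves test_multi_drivers to module level. The old layout built one driver in setup_class and touched it from every @magic_argv_env_context test, so later tests could inherit a stale default process group; the new layout gives each test its own driver and ends with the guarded teardown that now recurs throughout this file. If the repetition ever bothers the maintainers, the same teardown could live in a pytest fixture; a sketch under that assumption (the fixture name is hypothetical):

import pytest
import torch.distributed as dist

@pytest.fixture
def clean_process_group():
    yield  # run the test body first
    # then drop whatever process group the test created, so the next
    # test starts from a clean distributed state
    if dist.is_initialized():
        dist.destroy_process_group()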
@@ -187,7 +180,6 @@ class TestSetDistReproDataloader:
     @classmethod
     def setup_class(cls):
         cls.device = [0, 1]
-        cls.driver = generate_driver(10, 10, device=cls.device)

     def setup_method(self):
         self.dataset = TorchNormalDataset(40)
@@ -204,17 +196,20 @@ class TestSetDistReproDataloader:
         Tests set_dist_repro_dataloader when dist is a BucketedBatchSampler;
         the batch_sampler should be replaced with the BucketedBatchSampler given through dist
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler is batch_sampler
         self.check_distributed_sampler(replaced_loader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -223,9 +218,10 @@ class TestSetDistReproDataloader:
         Tests set_dist_repro_dataloader when dist is a RandomSampler;
         batch_sampler.sampler should be replaced with the RandomSampler given through dist
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle)
         sampler = RandomSampler(self.dataset, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -234,9 +230,11 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler is sampler
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)

         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()
     """
     The case where the argument `dist` is None; this occurs during trainer and evaluator initialization when the user specifies `use_dist_sampler`
@@ -251,15 +249,17 @@ class TestSetDistReproDataloader:
         Tests set_dist_repro_dataloader when dist is None and reproducible is True.
         When the user has initialized the distributed environment outside the driver, fastnlp does not support resuming from checkpoints, so an error should be raised
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True)
         with pytest.raises(RuntimeError):
             # a RuntimeError should be raised
-            replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True)
+            replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True)

         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
-    # @pytest.mark.parametrize("shuffle", ([True, False]))
     @pytest.mark.parametrize("shuffle", ([True, False]))
     def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle):
         """
@@ -268,21 +268,24 @@ class TestSetDistReproDataloader:
         the batch_sampler of the passed-in dataloader should already have had set_distributed executed; a new dataloader
         is produced whose batch_sampler is the same as the original dataloader's
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
         dataloader.batch_sampler.set_distributed(
-            num_replicas=self.driver.world_size,
-            rank=self.driver.global_rank,
+            num_replicas=driver.world_size,
+            rank=driver.global_rank,
             pad=True
         )
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
         assert replaced_loader.batch_sampler.batch_size == 4
         self.check_distributed_sampler(dataloader.batch_sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)

         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -292,12 +295,13 @@ class TestSetDistReproDataloader:
         the batch_sampler.sampler of the passed-in dataloader should already have had set_distributed executed; a new
         dataloader is produced whose batch_sampler.sampler is the same as the original dataloader's
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
         dataloader.batch_sampler.sampler.set_distributed(
-            num_replicas=self.driver.world_size,
-            rank=self.driver.global_rank
+            num_replicas=driver.world_size,
+            rank=driver.global_rank
         )
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -307,9 +311,11 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.batch_size == 4
         assert replaced_loader.batch_sampler.drop_last == False
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
-        self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
+        self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -318,11 +324,14 @@ class TestSetDistReproDataloader:
         Tests set_dist_repro_dataloader when dist is None, reproducible is False, and the dataloader is an ordinary one;
         the original dataloader should be returned directly, without any processing.
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False)

         assert replaced_loader is dataloader
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     """
     The case where the argument `dist` is 'dist'; this occurs during trainer initialization when the user specifies the `use_dist_sampler` argument
@@ -337,12 +346,13 @@ class TestSetDistReproDataloader:
         behavior
         A new dataloader should be returned whose batch_sampler is the same as the original dataloader's, with the distributed attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(
             dataset=self.dataset,
             batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle)
         )
         dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler)
@@ -351,6 +361,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -361,8 +373,9 @@ class TestSetDistReproDataloader:
         A new dataloader should be returned whose batch_sampler.sampler is the same as the original dataloader's, with the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)

         assert not (replaced_loader is dataloader)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
@@ -372,6 +385,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -381,8 +396,9 @@ class TestSetDistReproDataloader:
         A new dataloader should be returned with its batch_sampler.sampler replaced by a RandomSampler, with the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -392,6 +408,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     """
     The case where the argument `dist` is 'unrepeatdist'; this occurs during evaluator initialization when the user specifies the `use_dist_sampler` argument
     """
@@ -407,8 +425,9 @@ class TestSetDistReproDataloader:
         A new dataloader should be returned with the original Sampler replaced by an UnrepeatedRandomSampler, with the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -418,6 +437,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.batch_sampler.sampler.shuffle == shuffle
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -427,8 +448,9 @@ class TestSetDistReproDataloader:
         behavior
         A new dataloader should be returned that re-instantiates the original Sampler
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -439,6 +461,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     @magic_argv_env_context
     @pytest.mark.parametrize("shuffle", ([True, False]))
@@ -448,8 +472,9 @@ class TestSetDistReproDataloader:
         A new dataloader should be returned with the sampler replaced by an UnrepeatedSequentialSampler, with the distributed
         attributes set correctly
         """
+        driver = generate_driver(10, 10, device=self.device)
         dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle)
-        replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)
+        replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False)

         assert not (replaced_loader is dataloader)
         assert isinstance(replaced_loader.batch_sampler, BatchSampler)
@@ -459,6 +484,8 @@ class TestSetDistReproDataloader:
         assert replaced_loader.drop_last == dataloader.drop_last
         self.check_distributed_sampler(replaced_loader.batch_sampler.sampler)
         dist.barrier()
+        if dist.is_initialized():
+            dist.destroy_process_group()

     def check_distributed_sampler(self, sampler):
         """
@@ -469,7 +496,7 @@ class TestSetDistReproDataloader:
         if not isinstance(sampler, UnrepeatedSampler):
             assert sampler.pad == True

-    def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle):
+    def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle):
         """
         Tests whether the result of set_dist_repro_dataloader under multiple GPUs is correct
         """
@@ -501,8 +528,8 @@ class TestSetDistReproDataloader:
                 drop_last=False,
             )
             new_loader.batch_sampler.set_distributed(
-                num_replicas=self.driver.world_size,
-                rank=self.driver.global_rank,
+                num_replicas=driver.world_size,
+                rank=driver.global_rank,
                 pad=True
             )
             new_loader.batch_sampler.load_state_dict(sampler_states)
@@ -512,8 +539,8 @@ class TestSetDistReproDataloader:
             # rebuild the dataloader
             new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False)
             new_loader.batch_sampler.sampler.set_distributed(
-                num_replicas=self.driver.world_size,
-                rank=self.driver.global_rank
+                num_replicas=driver.world_size,
+                rank=driver.global_rank
             )
             new_loader.batch_sampler.sampler.load_state_dict(sampler_states)
         for idx, batch in enumerate(new_loader):
@@ -534,11 +561,6 @@ class TestSaveLoad:
     Tests the behavior of the save and load related functions under multiple GPUs
     """

-    @classmethod
-    def setup_class(cls):
-        # an error is raised if setup is not done here
-        cls.driver = generate_driver(10, 10)
-
     def setup_method(self):
         self.dataset = TorchArgMaxDataset(10, 20)

@@ -552,26 +574,26 @@ class TestSaveLoad:
             path = "model"

             dataloader = DataLoader(self.dataset, batch_size=2)
-            self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10)
+            driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10)

-            self.driver1.save_model(path, only_state_dict)
+            driver1.save_model(path, only_state_dict)

             # synchronize
             dist.barrier()
-            self.driver2.load_model(path, only_state_dict)
+            driver2.load_model(path, only_state_dict)

             for idx, batch in enumerate(dataloader):
-                batch = self.driver1.move_data_to_device(batch)
-                res1 = self.driver1.model(
+                batch = driver1.move_data_to_device(batch)
+                res1 = driver1.model(
                     batch,
-                    fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                    fastnlp_fn=driver1.model.module.model.evaluate_step,
                     # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
-                res2 = self.driver2.model(
+                res2 = driver2.model(
                     batch,
-                    fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                    fastnlp_fn=driver2.model.module.model.evaluate_step,
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
@@ -580,6 +602,9 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)

+        if dist.is_initialized():
+            dist.destroy_process_group()
+
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
     @pytest.mark.parametrize("fp16", ([True, False]))
@@ -593,7 +618,7 @@ class TestSaveLoad:
             path = "model.ckp"
             num_replicas = len(device)

-            self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
+            driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \
                 generate_driver(10, 10, device=device, fp16=False)
             dataloader = dataloader_with_bucketedbatchsampler(
                 self.dataset,
@@ -603,8 +628,8 @@ class TestSaveLoad:
                 drop_last=False
             )
             dataloader.batch_sampler.set_distributed(
-                num_replicas=self.driver1.world_size,
-                rank=self.driver1.global_rank,
+                num_replicas=driver1.world_size,
+                rank=driver1.global_rank,
                 pad=True
             )
             num_consumed_batches = 2
@@ -623,7 +648,7 @@ class TestSaveLoad:
             # save the state
             sampler_states = dataloader.batch_sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
-            self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+            driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             # load
             # change the batch_size
             dataloader = dataloader_with_bucketedbatchsampler(
@@ -634,11 +659,11 @@ class TestSaveLoad:
                 drop_last=False
             )
             dataloader.batch_sampler.set_distributed(
-                num_replicas=self.driver2.world_size,
-                rank=self.driver2.global_rank,
+                num_replicas=driver2.world_size,
+                rank=driver2.global_rank,
                 pad=True
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")
             # 1. check the state of the optimizer
             # TODO the optimizer's state_dict is always empty
@@ -652,7 +677,7 @@ class TestSaveLoad:

             # 3. check whether fp16 was loaded
             if fp16:
-                assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
+                assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

             # 4. check whether the model's parameters are correct
             # 5. check batch_idx
@@ -664,16 +689,16 @@ class TestSaveLoad:

                 left_x_batches.update(batch["x"])
                 left_y_batches.update(batch["y"])
-                res1 = self.driver1.model(
+                res1 = driver1.model(
                     batch,
-                    fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                    fastnlp_fn=driver1.model.module.model.evaluate_step,
                     # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
-                res2 = self.driver2.model(
+                res2 = driver2.model(
                     batch,
-                    fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                    fastnlp_fn=driver2.model.module.model.evaluate_step,
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
@@ -686,6 +711,9 @@ class TestSaveLoad:
         finally:
             rank_zero_rm(path)

+        if dist.is_initialized():
+            dist.destroy_process_group()
+
     @magic_argv_env_context
     @pytest.mark.parametrize("only_state_dict", ([True, False]))
     @pytest.mark.parametrize("fp16", ([True, False]))
@@ -700,13 +728,13 @@ class TestSaveLoad:

             num_replicas = len(device)

-            self.driver1 = generate_driver(10, 10, device=device, fp16=fp16)
-            self.driver2 = generate_driver(10, 10, device=device, fp16=False)
+            driver1 = generate_driver(10, 10, device=device, fp16=fp16)
+            driver2 = generate_driver(10, 10, device=device, fp16=False)

             dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False)
             dataloader.batch_sampler.sampler.set_distributed(
-                num_replicas=self.driver1.world_size,
-                rank=self.driver1.global_rank,
+                num_replicas=driver1.world_size,
+                rank=driver1.global_rank,
                 pad=True
             )
             num_consumed_batches = 2
@@ -726,18 +754,18 @@ class TestSaveLoad:
             sampler_states = dataloader.batch_sampler.sampler.state_dict()
             save_states = {"num_consumed_batches": num_consumed_batches}
             if only_state_dict:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
+                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True)
             else:
-                self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
+                driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))])
             # load
             # change the batch_size
             dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False)
             dataloader.batch_sampler.sampler.set_distributed(
-                num_replicas=self.driver2.world_size,
-                rank=self.driver2.global_rank,
+                num_replicas=driver2.world_size,
+                rank=driver2.global_rank,
                 pad=True
             )
-            load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
+            load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
             replaced_loader = load_states.pop("dataloader")

             # 1. check the state of the optimizer
@@ -753,7 +781,7 @@ class TestSaveLoad:
             assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"]
             # 3. check whether fp16 was loaded
             if fp16:
-                assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler)
+                assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler)

             # 4. check whether the model's parameters are correct
             # 5. check batch_idx
@@ -765,16 +793,16 @@ class TestSaveLoad:

                 left_x_batches.update(batch["x"])
                 left_y_batches.update(batch["y"])
-                res1 = self.driver1.model(
+                res1 = driver1.model(
                     batch,
-                    fastnlp_fn=self.driver1.model.module.model.evaluate_step,
+                    fastnlp_fn=driver1.model.module.model.evaluate_step,
                     # Driver.model -> DataParallel.module -> _FleetWrappingModel.model
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
-                res2 = self.driver2.model(
+                res2 = driver2.model(
                     batch,
-                    fastnlp_fn=self.driver2.model.module.model.evaluate_step,
+                    fastnlp_fn=driver2.model.module.model.evaluate_step,
                     fastnlp_signature_fn=None,
                     wo_auto_param_call=False,
                 )
@@ -786,4 +814,7 @@ class TestSaveLoad:
             assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas

         finally:
-            rank_zero_rm(path)
+            rank_zero_rm(path)
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
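All three save/load tests above follow one template: consume num_consumed_batches batches, snapshot the sampler state, save through driver1, rebuild the dataloader with a different batch size, load through driver2, and assert that the replaced loader yields exactly the samples this rank had not yet seen. The closing invariant can be condensed into a few lines (names mirror the tests; this is an illustration, not fastNLP API):

def check_resume_covers_rest(already_seen, replaced_loader, dataset, num_replicas):
    left = set()
    for batch in replaced_loader:
        left.update(batch["x"])
    assert len(left & already_seen) == 0    # nothing is replayed...
    # ...and together the two halves cover this rank's whole shard
    assert len(left | already_seen) == len(dataset) / num_replicas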

+ 8
- 8
tests/core/drivers/torch_driver/test_initialize_torch_driver.py

@@ -2,12 +2,14 @@ import pytest

 from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
 from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
+from fastNLP.envs import get_gpu_count
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.utils import magic_argv_env_context

-import torch
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch import device as torchdevice
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as torchdevice

 @pytest.mark.torch
 def test_incorrect_driver():
@@ -20,7 +22,7 @@ def test_incorrect_driver():
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
-    ["cpu", "cuda:0", 0, torch.device("cuda:0")]
+    ["cpu", "cuda:0", 0, torchdevice("cuda:0")]
 )
 @pytest.mark.parametrize(
     "driver",
@@ -83,7 +85,6 @@ def test_get_ddp(driver, device):
     ("driver", "device"),
     [("torch_ddp", "cpu")]
 )
-@magic_argv_env_context
 def test_get_ddp_cpu(driver, device):
     """
     Tests attempting to initialize distributed training on cpu
@@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device):
 @pytest.mark.torch
 @pytest.mark.parametrize(
     "device",
-    [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1]
+    [-2, [0, 20, 3], [-2], 20]
 )
 @pytest.mark.parametrize(
     "driver",
     ["torch", "torch_ddp"]
 )
-@magic_argv_env_context
 def test_device_out_of_range(driver, device):
     """
     Tests the case where the given device is out of range
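The torchdevice alias exists because @pytest.mark.parametrize arguments are evaluated while the module is being collected: a bare torch.device("cuda:0") or torch.cuda.device_count() at module level would crash collection on a machine without torch. Guarding the import and substituting an inert stand-in keeps collection alive, and the device_count() expressions are replaced by the constant 20 for the same reason. The general pattern for any optional dependency looks roughly like this (illustrative names):

try:
    from torch import device as torchdevice
except ImportError:
    class torchdevice:
        # inert stand-in so module-level parametrize lists still build;
        # the torch-marked tests that would use it are skipped anyway
        def __init__(self, *args, **kwargs):
            pass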


+ 12
- 12
tests/core/drivers/torch_driver/test_single_device.py

@@ -2,7 +2,7 @@ import pytest
 from pathlib import Path

 from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
-from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
+from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
 from tests.helpers.datasets.paddle_data import PaddleNormalDataset
@@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE:

 def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
     """
-    Builds a dataloader whose batch_sampler is a RandomBatchSampler
+    Builds a dataloader whose batch_sampler is a ReproduceBatchSampler
     """
     if shuffle:
         sampler = torch.utils.data.RandomSampler(dataset)
@@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
     else:
         sampler = torch.utils.data.SequentialSampler(dataset)
     dataloader = DataLoader(
         dataset=dataset,
-        batch_sampler=RandomBatchSampler(
+        batch_sampler=ReproduceBatchSampler(
             BatchSampler(
                 sampler, batch_size=batch_size, drop_last=drop_last
             ),
@@ -306,7 +306,7 @@ class TestTorchDriverFunctions:
         res = TorchSingleDriver.get_dataloader_args(dataloader)

         assert isinstance(res.dataset, TorchNormalDataset)
-        assert isinstance(res.batch_sampler, RandomBatchSampler)
+        assert isinstance(res.batch_sampler, ReproduceBatchSampler)
         if shuffle:
             assert isinstance(res.sampler, torch.utils.data.RandomSampler)
         else:
@@ -401,7 +401,7 @@ class TestSetDistReproDataloader:
         """
         Tests how set_dist_repro_dataloader behaves when `reproducible` is True.
         When dist is a string, a new dataloader should be returned; if the original sampler is torch.utils.data.RandomSampler(shuffle=True),
-        only the sampler is replaced with RandomSampler, otherwise the batch_sampler is replaced with RandomBatchSampler
+        only the sampler is replaced with RandomSampler, otherwise the batch_sampler is replaced with ReproduceBatchSampler
         """
         dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -414,7 +414,7 @@ class TestSetDistReproDataloader:
             assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
         else:
             # in this case the batch_sampler is replaced
-            assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+            assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
             assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
             assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
@@ -428,11 +428,11 @@ class TestSetDistReproDataloader:
         A new dataloader should be returned, with its batch_sampler replaced by the Sampler given through dist
         """
         dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
-        dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
+        dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)

         assert not (replaced_loader is dataloader)
-        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+        assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
         assert replaced_loader.batch_sampler is dist

         self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -466,7 +466,7 @@ class TestSetDistReproDataloader:
         replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)

         assert not (replaced_loader is dataloader)
-        assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+        assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
         assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
         assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
         assert replaced_loader.drop_last == dataloader.drop_last
@@ -502,14 +502,14 @@ class TestSetDistReproDataloader:
             if idx >= num_consumed_batches:
                 break
             already_seen_idx.update(batch)
-        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
+        if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
             sampler_states = replaced_loader.batch_sampler.state_dict()
         else:
             sampler_states = replaced_loader.batch_sampler.sampler.state_dict()

         # after reloading, the remaining content should be produced, and for TorchNormalDataset the sorted indices should form a range
         left_idxes = set()
-        if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
+        if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
             batch_size = replaced_loader.batch_sampler.batch_size
             sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
             # rebuild the dataloader
@@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
     # 2. check that the batch_sampler was loaded and replaced correctly
     assert not (replaced_loader is dataloader)
     assert replaced_loader.batch_sampler is dataloader.batch_sampler
-    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+    assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
     assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
     assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4
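One detail in the hunk at old line 502 deserves a note: before reloading, the test overwrites num_consumed_samples with num_consumed_batches * batch_size rather than trusting the saved value, presumably because the dataloader may have fetched ahead of what the loop actually processed when state_dict() was taken. Worked through with the test's own numbers:

num_consumed_batches = 2                          # batches actually processed
batch_size = 4
sampler_states = {"num_consumed_samples": 12}     # hypothetical value, inflated by prefetching
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
assert sampler_states["num_consumed_samples"] == 8   # resume after exactly 8 samples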




+ 1
- 1
tests/core/drivers/torch_driver/test_torch_replace_sampler.py

@@ -30,7 +30,7 @@ class SequenceDataSet:


 def check_replace_sampler(driver):
-    # dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler or RandomBatchSampler
+    # dist_sampler can be one of ['dist', 'unrepeatdist', None], or a ReproducibleSampler or ReproduceBatchSampler
     # reproducible is True or False

     # need to check that the returned sampler and dataloader are both different


+ 3
- 3
tests/core/drivers/torch_driver/test_utils.py

@@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
     replace_batch_sampler,
     replace_sampler,
 )
-from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
+from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
 from torch.utils.data import DataLoader, BatchSampler

 from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
 def test_replace_batch_sampler():
     dataset = TorchNormalDataset(10)
     dataloader = DataLoader(dataset, batch_size=32)
-    batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
+    batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)

     replaced_loader = replace_batch_sampler(dataloader, batch_sampler)

     assert not (replaced_loader is dataloader)
-    assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
+    assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
     assert isinstance(replaced_loader.dataset, TorchNormalDataset)
     assert len(replaced_loader.dataset) == len(dataset)
     assert replaced_loader.batch_sampler.batch_size == 16


+ 9
- 4
tests/core/metrics/test_accuracy_torch.py

@@ -7,15 +7,20 @@ import copy
 import socket
 import pytest
 import numpy as np
-import torch
-import torch.distributed
-from torch.multiprocessing import Pool, set_start_method
 from sklearn.metrics import accuracy_score as sklearn_accuracy

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.metrics.accuracy import Accuracy
 from fastNLP.core.metrics.metric import Metric
 from .utils import find_free_network_port, setup_ddp, _assert_allclose
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed
+    from torch.multiprocessing import Pool, set_start_method
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

 set_start_method("spawn", force=True)

@@ -26,7 +31,7 @@ pool = None

 def _test(local_rank: int,
           world_size: int,
-          device: torch.device,
+          device: "torch.device",
           dataset: DataSet,
           metric_class: Type[Metric],
           metric_kwargs: Dict[str, Any],


+8 -3  tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py

@@ -2,18 +2,23 @@ from functools import partial
import copy


import pytest
-import torch
import numpy as np
-from torch.multiprocessing import Pool, set_start_method

from fastNLP.core.metrics import ClassifyFPreRecMetric
from fastNLP.core.dataset import DataSet
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.multiprocessing import Pool, set_start_method
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

set_start_method("spawn", force=True)


-def _test(local_rank: int, world_size: int, device: torch.device,
+def _test(local_rank: int, world_size: int, device: "torch.device",
          dataset: DataSet, metric_class, metric_kwargs, metric_result):
    metric = metric_class(**metric_kwargs)
    # the dataset works similarly (each process holds its own copy)

+9 -4  tests/core/metrics/test_span_f1_rec_acc_torch.py

@@ -5,16 +5,21 @@ import os, sys
import copy
from functools import partial

-import torch
-import torch.distributed
import numpy as np
import socket
-from torch.multiprocessing import Pool, set_start_method
# from multiprocessing import Pool, set_start_method
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.dataset import DataSet
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from .utils import find_free_network_port, setup_ddp
+if _NEED_IMPORT_TORCH:
+    import torch
+    import torch.distributed
+    from torch.multiprocessing import Pool, set_start_method
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as set_start_method

set_start_method("spawn", force=True)

@@ -44,7 +49,7 @@ pool = None

def _test(local_rank: int,
          world_size: int,
-         device: torch.device,
+         device: "torch.device",
          dataset: DataSet,
          metric_class,
          metric_kwargs,


+4 -2  tests/core/metrics/utils.py

@@ -2,9 +2,11 @@ import os, sys
import socket
from typing import Union

-import torch
-from torch import distributed
import numpy as np
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch import distributed


def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
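
The same guarded-import pattern recurs across these metric tests: torch names are only bound when the backend is available, and set_start_method is aliased to fastNLP's DummyClass otherwise so the module-level call still succeeds. A minimal self-contained sketch of the idea, using importlib as a stand-in for fastNLP's _NEED_IMPORT_TORCH flag (assumption: the real flag is also driven by the configured backend, not just package availability):

import importlib.util

# Stand-in for fastNLP's _NEED_IMPORT_TORCH (hypothetical simplification).
TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None

if TORCH_AVAILABLE:
    import torch
    from torch.multiprocessing import set_start_method
else:
    class _Noop:
        # Accepts any call signature and does nothing, DummyClass-style.
        def __init__(self, *args, **kwargs):
            pass
    set_start_method = _Noop

set_start_method("spawn", force=True)  # harmless whether or not torch exists

def _test(local_rank: int, device: "torch.device") -> None:
    # The quoted annotation is not evaluated at runtime, so the module
    # still imports cleanly when the name `torch` was never bound.
    ...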


+433 -153  tests/core/samplers/test_reproducible_batch_sampler.py

@@ -1,161 +1,131 @@
-from array import array

import numpy as np
import pytest
from itertools import chain
from copy import deepcopy
+from array import array

+from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
+from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler


class TestReproducibleBatchSampler:
def test_1(self):
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler makes no difference here;

reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)

forward_steps = 3
iterator = iter(reproduce_batch_sampler)
i = 0
while i < forward_steps:
next(iterator)
i += 1

        # save the state;
state = reproduce_batch_sampler.state_dict()

assert state == {"index_list": array("I", list(range(100))),
"num_consumed_samples": forward_steps * 4,
"sampler_type": "ReproduceBatchSampler"}

        # create a new batch sampler and load the state;
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler makes no difference here;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

real_res = []
supposed_res = (list(range(12, 16)), list(range(16, 20)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert supposed_res[i] == real_res[i]

        # change the batch size;
        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler makes no difference here;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

real_res = []
supposed_res = (list(range(12, 19)), list(range(19, 26)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert supposed_res[i] == real_res[i]

        # check that the second epoch after resuming is a complete pass;
        # first finish the epoch in which training resumed;
begin_idx = 26
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break

        # start a new epoch;
begin_idx = 0
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break


-from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
-from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
-from tests.helpers.datasets.torch_data import TorchNormalDataset

#
# class TestReproducibleBatchSampler:
# # TODO split this up; test only one thing here
# def test_torch_dataloader_1(self):
# import torch
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# next(iter_dataloader)
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "RandomBatchSampler"}
#
# # 2. resume from checkpoint and create a new dataloader;
# # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
# # change batch_size;
# after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
# # check that the second epoch after resuming is a complete pass over the dataloader;
# # first finish the epoch in which training resumed;
# begin_idx = 27
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# # start a new epoch;
# begin_idx = 0
# iter_dataloader = iter(dataloader)
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# def test_torch_dataloader_2(self):
# # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle, to verify that the second epoch's index list is regenerated after resuming;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # record one epoch's data in full, to check that recovery is correct;
# all_supposed_data = []
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# all_supposed_data.extend(next(iter_dataloader).tolist())
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# state = _get_re_batchsampler.state_dict()
#
# # 2. resume from checkpoint and create a new dataloader;
# # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # first consume the rest of this epoch;
# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
# while True:
# try:
# all_supposed_data.extend(next(iter_dataloader).tolist())
# except StopIteration:
# break
# assert all_supposed_data == list(pre_index_list)
#
# # start new epochs;
# for _ in range(3):
# iter_dataloader = iter(dataloader)
# res = []
# while True:
# try:
# res.append(next(iter_dataloader))
# except StopIteration:
# break
#
# def test_3(self):
# import torch
# from torch.utils.data import DataLoader
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle, to verify that the second epoch's index list is regenerated after resuming;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
#
# for idx, data in enumerate(dataloader):
# if idx > 3:
# break
#
# iterator = iter(dataloader)
# for each in iterator:
# pass
def test_2(self):

        # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
before_batch_size = 7
sampler = NormalSampler(num_of_data=100)
        # enable shuffle, to verify that the second epoch's index list is regenerated after resuming;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)

        # record one epoch's data in full, to check that recovery is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader))

        # 1. save the state
state = reproduce_batch_sampler.state_dict()

        # 2. resume from checkpoint and create a new dataloader;
        # keep batch_size unchanged;
sampler = NormalSampler(num_of_data=100, shuffle=True)
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

        # first consume the rest of this epoch;
pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
all_supposed_data.extend(next(iter_dataloader))
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

        # start new epochs;
for _ in range(3):
iter_dataloader = iter(reproduce_batch_sampler)
res = []
while True:
try:
res.extend(next(iter_dataloader))
except StopIteration:
break
assert res != all_supposed_data




class DatasetWithVaryLength:
@@ -511,3 +481,313 @@ class TestBucketedBatchSampler:
                already_seen_set.update(batch)


            assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)


class TestRandomBatchSampler:
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
def test_single_num_batch(self, shuffle, drop_last, num):
        # insufficient data should not raise an error
for num in [2, 7, 14, 15, 70, 71]:
dataset = DatasetWithVaryLength(num_of_data=num)
before_batch_size = 7
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
count = len(list(iter(re_batchsampler)))
if drop_last:
assert count==num//before_batch_size, num
else:
assert count==(num+before_batch_size-1)//before_batch_size, num

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
def test_single(self, shuffle, drop_last):

before_batch_size = 7
        num_batch_per_bucket = 4  # so the length difference within any batch should not exceed 4

dataset = DatasetWithVaryLength(num_of_data=1000)
re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler.set_epoch(0)
forward_steps = 10
iterator = iter(re_batchsampler)
already_generate_indices = set()
for _ in range(forward_steps):
batch = next(iterator)
already_generate_indices.update(batch)

        # 1. save the state
state = re_batchsampler.state_dict()

        # 2. resume from checkpoint and continue training
re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler2.load_state_dict(state)
re_batchsampler2.set_epoch(0)
new_already_generate_indices = set()
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
for batch in re_batchsampler2:
for b in batch:
assert b not in already_generate_indices
new_already_generate_indices.update(batch)
if drop_last is False:
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)

        # change batch_size;
after_batch_size = 3
re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=drop_last,
shuffle=shuffle)
re_batchsampler3.load_state_dict(state)
re_batchsampler3.set_epoch(0)
count = 0

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]

for batch in re_batchsampler3:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break

        # save again; sampling must not continue from an unfinished previous epoch
after_batch_size = 5
with pytest.raises(RuntimeError):
state = re_batchsampler3.state_dict()

        for batch in re_batchsampler3:  # consume everything, so that saving is allowed
pass

already_generate_indices = set()
count = 0
        for batch in re_batchsampler3:  # start over
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break

state = re_batchsampler3.state_dict()
        # drop_last is False here, so every sample must eventually appear
re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
drop_last=False,
shuffle=shuffle)
re_batchsampler4.load_state_dict(state)
re_batchsampler4.set_epoch(0)

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
for batch in re_batchsampler4:
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)

assert len(already_generate_indices) == len(dataset)

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
def test_multi(self, shuffle, drop_last, pad):
# def test_multi(self, shuffle=True, drop_last=False, pad=False):

# no shuffle
num_replica = 2
dataset = DatasetWithVaryLength(num_of_data=1000)
batch_size = 5
num_batch_per_bucket = 10
lengths = []
rank0_already_seen_indexes = None
max_diff = num_batch_per_bucket * batch_size * num_replica
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
already_seen_indexes = set()
repeat_count = 0
for batch in sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
                    if rank0_already_seen_indexes:  # indices must not overlap across ranks
assert b not in rank0_already_seen_indexes
already_seen_indexes.update(batch)
if rank0_already_seen_indexes is None:
rank0_already_seen_indexes = already_seen_indexes
            if pad:  # at most one repeat should be allowed
assert repeat_count<=1
else:
assert repeat_count==0

        assert len(set(lengths))==1, lengths  # every process sees the same number of batches

        # saving in the multi-process setting
already_seen_indexes = set()
for rank in range(num_replica):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
count = 0
for batch in sampler:
already_seen_indexes.update(batch)
if count>5:
break
count += 1
state = sampler.state_dict()

        # switch to a single machine
new_batch_size = 6
num_batch_per_bucket = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.load_state_dict(state)
repeat_count = 0
new_already_seen_indexes = set(list(already_seen_indexes))

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]

for batch in new_sampler:
for b in batch:
repeat_count += int(b in new_already_seen_indexes)
new_already_seen_indexes.update(batch)
        if pad:  # at most one repeat should be allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
        if drop_last is False:  # without dropping, the counts should match
assert len(new_already_seen_indexes)==len(dataset)

        # test changing the number of devices.
num_replica = 3
new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
shuffle=shuffle, drop_last=drop_last)
new_sampler.set_epoch(0)
new_sampler.load_state_dict(state)
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
repeat_count = 0

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]

for batch in new_sampler:
for b in batch:
repeat_count += int(b in already_seen_indexes)
        if pad:  # at most one repeat should be allowed
assert repeat_count <= 1
else:
assert repeat_count == 0

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2):
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replicas*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replicas, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replicas', [1, 2, 3])
def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
"""
测试是否能够正确地恢复使用过的(forward)数据

:return:
"""
batch_size = 6
dataset = DatasetWithVaryLength(num_of_data=num_samples)
samplers = []
num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas))
for i in range(num_replicas):
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
shuffle=shuffle, drop_last=drop_last)

sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
samplers.append(sampler)
count = 0
already_seen_sets = [set()]
already_seen_set = set()
for batchs in zip(*samplers):
batch = chain(*batchs)
already_seen_set.update(batch)
already_seen_sets.append(deepcopy(already_seen_set))
count += 1
if count > 3:
break
states = samplers[0].state_dict()
for i in range(len(already_seen_sets)):
states['num_consumed_samples'] = num_consumed_samples_array[i]
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1,
shuffle=shuffle, drop_last=drop_last)
sampler.set_epoch(0)
already_seen_set = deepcopy(already_seen_sets[i])
for batch in sampler:
already_seen_set.update(batch)
assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len(
dataset)

        # test saving again after a previous save
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
shuffle=shuffle,
drop_last=drop_last)
sampler.set_epoch(0)
states['num_consumed_samples'] = num_consumed_samples_array[2]
if len(already_seen_sets)<3:
return
already_seen_set = already_seen_sets[2]
count = 0
for batch in sampler:
already_seen_set.update(batch)
count += 1
if count > 6:
break

states = sampler.state_dict()
num_consumed_samples_array = list(range(len(dataset)))
states['num_consumed_samples'] = num_consumed_samples_array[count]
sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2,
shuffle=shuffle,
drop_last=drop_last)
sampler.load_state_dict(states)
sampler.set_epoch(0)
for batch in sampler:
already_seen_set.update(batch)

assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)
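
Taken together, these tests pin down the resume contract: state_dict() captures the epoch's index list plus the number of consumed samples, and a freshly constructed sampler that loads that state skips exactly the consumed prefix, even under a different batch size. A condensed sketch of the flow exercised by test_1 above (assuming, as there, that a plain index iterator such as NormalSampler can serve as the wrapped sampler):

from tests.helpers.datasets.normal_data import NormalSampler
from fastNLP.core.samplers import ReproduceBatchSampler

sampler = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=4, drop_last=False)
iterator = iter(sampler)
for _ in range(3):
    next(iterator)               # consume 3 batches = 12 samples

state = sampler.state_dict()     # index_list, num_consumed_samples=12, sampler_type

# A new run resumes mid-epoch; a larger batch size simply re-chunks the remainder.
resumed = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=7, drop_last=False)
resumed.load_state_dict(state)
assert next(iter(resumed)) == list(range(12, 19))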

+141 -0  tests/core/samplers/test_reproducible_batch_sampler_torch.py

@@ -0,0 +1,141 @@
from array import array
import torch
from torch.utils.data import DataLoader

import pytest

from fastNLP.core.samplers import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


@pytest.mark.torch
class TestReproducibleBatchSamplerTorch:
def test_torch_dataloader_1(self):
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
next(iter_dataloader)

        # 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "ReproduceBatchSampler"}

        # 2. resume from checkpoint and create a new dataloader;
        # keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

        # change batch_size;
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

        # check that the second epoch after resuming is a complete pass;
        # first finish the epoch in which training resumed;
begin_idx = 27
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

        # start a new epoch;
begin_idx = 0
iter_dataloader = iter(dataloader)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

def test_torch_dataloader_2(self):
        # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
from torch.utils.data import DataLoader
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle, to verify that the second epoch's index list is regenerated after resuming;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # record one epoch's data in full, to check that recovery is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader).tolist())

        # 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()

        # 2. resume from checkpoint and create a new dataloader;
        # keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

iter_dataloader = iter(dataloader)
        # first consume the rest of this epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
while True:
try:
all_supposed_data.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

        # start new epochs;
for _ in range(3):
iter_dataloader = iter(dataloader)
res = []
while True:
try:
res.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert res != all_supposed_data


+1 -0  tests/core/utils/test_cache_results.py

@@ -3,6 +3,7 @@ import pytest
import subprocess
from io import StringIO
import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), '../../..'))

from fastNLP.core.utils.cache_results import cache_results
from fastNLP.core import rank_zero_rm


+2 -1  tests/envs/test_set_backend.py

@@ -1,4 +1,5 @@
import os
+import pytest

from fastNLP.envs.set_backend import dump_fastnlp_backend
from tests.helpers.utils import Capturing
@@ -9,7 +10,7 @@ def test_dump_fastnlp_envs():
    filepath = None
    try:
        with Capturing() as output:
-            dump_fastnlp_backend()
+            dump_fastnlp_backend(backend="torch")
        filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json')
        assert filepath in output[0]
        assert os.path.exists(filepath)


+3 -1  tests/helpers/callbacks/helper_callbacks_torch.py

@@ -1,7 +1,9 @@
-import torch
from copy import deepcopy

from fastNLP.core.callbacks.callback import Callback
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch


class RecordAccumulationStepsCallback_Torch(Callback):


+52 -4  tests/helpers/datasets/normal_data.py

@@ -1,13 +1,25 @@
import numpy as np
+import random


-class NormalIterator:
-    def __init__(self, num_of_data=1000):
+class NormalSampler:
+    def __init__(self, num_of_data=1000, shuffle=False):
        self._num_of_data = num_of_data
        self._data = list(range(num_of_data))
+        if shuffle:
+            random.shuffle(self._data)
+        self.shuffle = shuffle
        self._index = 0
+        self.need_reinitialize = False

    def __iter__(self):
+        if self.need_reinitialize:
+            self._index = 0
+            if self.shuffle:
+                random.shuffle(self._data)
+        else:
+            self.need_reinitialize = True

        return self


    def __next__(self):
@@ -15,12 +27,45 @@ class NormalIterator:
        raise StopIteration
        _data = self._data[self._index]
        self._index += 1
-        return self._data
+        return _data

    def __len__(self):
        return self._num_of_data




class NormalBatchSampler:
def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last

def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch

def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
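
For reference, these two helpers compose like torch's BatchSampler but without the torch dependency; a quick sanity check of the behaviour defined above:

sampler = NormalSampler(num_of_data=10)
batch_sampler = NormalBatchSampler(sampler, batch_size=4, drop_last=False)
assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert len(batch_sampler) == 3  # the trailing partial batch is kept

batch_sampler = NormalBatchSampler(NormalSampler(num_of_data=10), batch_size=4, drop_last=True)
assert len(batch_sampler) == 2  # the partial batch is dropped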


class RandomDataset:
    def __init__(self, num_data=10):
        self.data = np.random.rand(num_data)
@@ -29,4 +74,7 @@ class RandomDataset:
        return len(self.data)


    def __getitem__(self, item):
        return self.data[item]




+6 -2  tests/helpers/datasets/torch_data.py

@@ -1,7 +1,11 @@
import torch
from functools import reduce
-from torch.utils.data import Dataset, DataLoader, DistributedSampler
-from torch.utils.data.sampler import SequentialSampler, BatchSampler
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    from torch.utils.data import Dataset, DataLoader, DistributedSampler
+    from torch.utils.data.sampler import SequentialSampler, BatchSampler
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as Dataset


class TorchNormalDataset(Dataset):


+10 -5  tests/helpers/models/torch_model.py

@@ -1,9 +1,14 @@
-import torch
-import torch.nn as nn
+from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+if _NEED_IMPORT_TORCH:
+    import torch
+    from torch.nn import Module
+    import torch.nn as nn
+else:
+    from fastNLP.core.utils.dummy_class import DummyClass as Module


# 1. the most basic classification model
-class TorchNormalModel_Classification_1(nn.Module):
+class TorchNormalModel_Classification_1(Module):
    """
    implements train_step and evaluate_step separately;
    """
@@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module):
return {"preds": x, "target": y} return {"preds": x, "target": y}




class TorchNormalModel_Classification_2(nn.Module):
class TorchNormalModel_Classification_2(Module):
""" """
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
""" """
@@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y} return {"loss": loss, "preds": x, "target": y}




class TorchNormalModel_Classification_3(nn.Module):
class TorchNormalModel_Classification_3(Module):
""" """
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
关闭 auto_param_call,forward 只有一个 batch 参数; 关闭 auto_param_call,forward 只有一个 batch 参数;
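
The Module alias above works because DummyClass can be subclassed; a minimal stand-in showing why these class definitions still import when torch is absent (the real fastNLP DummyClass may differ in detail):

class DummyClass:
    # Accepts any constructor arguments and ignores them, so it can pose as
    # torch.nn.Module, set_start_method, or any other torch name at import time.
    def __init__(self, *args, **kwargs):
        pass

Module = DummyClass  # what the guarded import binds when torch is missing

class ModelWithoutTorch(Module):  # defining the class needs no real torch
    """Only instantiating or training the model would require torch itself."""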


+6 -0  tests/pytest.ini

@@ -0,0 +1,6 @@
[pytest]
markers =
torch
paddle
jittor
torchpaddle
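
Registering these markers silences pytest's PytestUnknownMarkWarning and lets backend-specific suites be selected from the command line, e.g. "pytest -m torch" to run only the torch tests, or "pytest -m 'not paddle'" to skip the paddle ones.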
