Browse Source

修复了测试的冲突

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
5fad600ebc
25 changed files with 1073 additions and 712 deletions
  1. +58
    -10
      fastNLP/core/__init__.py
  2. +2
    -3
      fastNLP/core/callbacks/__init__.py
  3. +27
    -29
      fastNLP/core/callbacks/callback.py
  4. +489
    -0
      fastNLP/core/callbacks/callback_event.py
  5. +0
    -206
      fastNLP/core/callbacks/callback_events.py
  6. +3
    -3
      fastNLP/core/callbacks/callback_manager.py
  7. +17
    -1
      fastNLP/core/collators/__init__.py
  8. +30
    -0
      fastNLP/core/collators/padders/__init__.py
  9. +6
    -25
      fastNLP/core/collators/padders/get_padder.py
  10. +4
    -1
      fastNLP/core/collators/padders/raw_padder.py
  11. +5
    -1
      fastNLP/core/collators/padders/torch_padder.py
  12. +5
    -1
      fastNLP/core/collators/padders/utils.py
  13. +0
    -2
      fastNLP/core/controllers/__init__.py
  14. +58
    -8
      fastNLP/core/controllers/trainer.py
  15. +0
    -9
      fastNLP/core/dataset/dataset.py
  16. +0
    -1
      fastNLP/core/samplers/reproducible_batch_sampler.py
  17. +2
    -0
      fastNLP/core/samplers/reproducible_sampler.py
  18. +5
    -0
      fastNLP/core/utils/paddle_utils.py
  19. +35
    -4
      fastNLP/io/data_bundle.py
  20. +1
    -1
      tests/core/callbacks/test_callback_event.py
  21. +2
    -2
      tests/core/callbacks/test_checkpoint_callback_torch.py
  22. +287
    -75
      tests/core/collators/test_collator.py
  23. +0
    -293
      tests/core/collators/test_new_collator.py
  24. +35
    -35
      tests/core/controllers/test_trainer_event_trigger.py
  25. +2
    -2
      tests/core/controllers/test_trainer_wo_evaluator_torch.py

+ 58
- 10
fastNLP/core/__init__.py View File

@@ -1,4 +1,53 @@
__all__ = [
# callbacks
'Callback',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
'choose_progress_callback',
'ProgressCallback',
'RichCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback",
'MoreEvaluateCallback',
"TorchWarmupCallback",
"TorchGradClipCallback",

# collators
'Collator',
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",

# controllers
'Loop',
'EvaluateBatchLoop',
'TrainBatchLoop',
'Evaluator',
'Trainer',

# dataloaders TODO 需要把 mix_dataloader 的搞定

# dataset
'DataSet',
'FieldArray',
'Instance',
'ApplyResultException',

# drivers
"TorchSingleDriver",
"TorchDDPDriver",
"PaddleSingleDriver",
@@ -7,16 +56,15 @@ __all__ = [
"JittorMPIDriver",
"TorchPaddleDriver",

"paddle_to",
"get_paddle_gpu_str",
"get_paddle_device_id",
"paddle_move_data_to_device",
"torch_paddle_move_data_to_device",
]
# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import;
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.controllers.evaluator import Evaluator
from fastNLP.core.dataloaders.torch_dataloader import *
# log
"logger"

]
from .callbacks import *
from .collators import *
from .controllers import *
from .dataloaders import *
from .dataset import *
from .drivers import *
from .log import *
from .utils import *

+ 2
- 3
fastNLP/core/callbacks/__init__.py View File

@@ -1,7 +1,6 @@
__all__ = [
'Callback',
'Events',
'EventsList',
'Event',
'Filter',
'CallbackManager',
'CheckpointCallback',
@@ -20,7 +19,7 @@ __all__ = [


from .callback import Callback
from .callback_events import EventsList, Events, Filter
from .callback_event import Event, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback


+ 27
- 29
fastNLP/core/callbacks/callback.py View File

@@ -3,10 +3,9 @@ __all__ = [
'Callback',
]

from typing import Union, Callable, Dict, Optional, Any
from typing import Callable, Dict, Optional

from .callback_events import Events, EventsList, Filter
from fastNLP.core.callbacks.callback_events import _SingleEventState
from .callback_event import Event, Filter


class Callback:
@@ -14,32 +13,35 @@ class Callback:
实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类;
callback 调用时机顺序大概如下
Trainer.__init__():
on_after_trainer_initialized()
on_after_trainer_initialized(trainer, driver)
Trainer.run():
if num_eval_sanity_batch>0:
on_sanity_check_begin() # 如果设置了num_eval_sanity_batch
on_sanity_check_end()
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
on_sanity_check_end(trainer, sanity_check_res)
try:
on_train_begin()
on_train_begin(trainer)
while cur_epoch_idx < n_epochs:
on_train_epoch_begin()
on_train_epoch_begin(trainer)
while batch_idx_in_epoch<=num_batches_per_epoch:
on_fetch_data_begin()
on_fetch_data_end()
on_train_batch_begin()
on_before_backward()
on_after_backward()
on_before_zero_grad() # 实际调用受到 accumulation_steps 影响
on_after_zero_grad() # 实际调用受到 accumulation_steps 影响
on_before_optimizers_step() # 实际调用受到 accumulation_steps 影响
on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响
on_train_batch_end()
on_train_epoch_end()
on_fetch_data_begin(trainer)
batch = next(dataloader)
on_fetch_data_end(trainer)
on_train_batch_begin(trainer, batch, indices)
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
on_after_backward(trainer)
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_train_batch_end(trainer)
on_train_epoch_end(trainer)
except BaseException:
self.on_exception()
self.on_exception(trainer, exception)
finally:
on_train_end()
其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将
on_train_end(trainer)
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定
的时间调用。
"""

def on_after_trainer_initialized(self, trainer, driver):
@@ -294,18 +296,14 @@ class _CallbackWrapper(Callback):
对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的
这一个 callback 函数;
"""
def __init__(self, event: Union[Events, EventsList], fn: Callable):
def __init__(self, event: Event, fn: Callable):
r"""
:param event: 具体的 callback 时机,例如 'on_train_begin' 等;可以多个时机,此时 `event` 的 type 应当为 'EventsList';
:param event: 具体的 callback 时机,例如 'on_train_begin' 等;
:param fn: 用户定制的 callback 函数;
"""

self.fn = fn
if isinstance(event, EventsList):
for each_event in event:
_filter = Filter(each_event.every, each_event.once, each_event.filter_fn)
setattr(self, each_event.value, _filter(fn))
elif isinstance(event, _SingleEventState):
if isinstance(event, Event):
_filter = Filter(event.every, event.once, event.filter_fn)
setattr(self, event.value, _filter(fn))



+ 489
- 0
fastNLP/core/callbacks/callback_event.py View File

@@ -0,0 +1,489 @@
from typing import Optional, Callable, Dict
from functools import wraps


__all__ = [
'Event',
'Filter'
]


def check_legality(fn):
@wraps(fn)
def wrap(every=1, once=None, filter_fn=None):
if (every is None) and (once is None) and (filter_fn is None):
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.")

if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)):
raise ValueError("These three values should be only set one.")

if (filter_fn is not None) and not callable(filter_fn):
raise TypeError("Argument event_filter should be a callable")

if (every is not None) and not (isinstance(every, int) and every > 0):
raise ValueError("Argument every should be integer and greater than zero")

if (once is not None) and not (isinstance(once, int) and once > 0):
raise ValueError("Argument once should be integer and positive")
return fn(every=every, once=once, filter_fn=filter_fn)
return wrap


class Event:
every: Optional[int]
once: Optional[int]

def __init__(self, value: str, every: Optional[int] = 1, once: Optional[int] = False,
filter_fn: Optional[Callable] = None):

self.every = every
self.once = once
self.filter_fn = filter_fn
self.value = value

def __str__(self):
return "<event={0}, every={1}, once={2}, filter fn is:{3}>".format(self.value, self.every, self.once,
self.filter_fn)
@staticmethod
@check_legality
def on_after_trainer_initialized(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_after_trainer_initialized 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_sanity_check_begin(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_sanity_check_begin 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_sanity_check_end(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_sanity_check_end 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_begin(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_train_begin 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_end(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_train_end 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_epoch_begin(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_train_epoch_begin 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_epoch_end(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_train_epoch_end 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_fetch_data_begin(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_fetch_data_begin 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_fetch_data_end(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_fetch_data_end 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_batch_begin(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_train_batch_begin 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_train_batch_end(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_train_batch_end 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_exception(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_exception 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_save_model(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_save_model 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_load_model(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_load_model 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_save_checkpoint(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_save_checkpoint 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_load_checkpoint(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_load_checkpoint 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_load_checkpoint(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_load_checkpoint 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_before_backward(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_before_backward 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_after_backward(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_after_backward 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_before_optimizers_step(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_before_optimizers_step 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_after_optimizers_step(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_after_optimizers_step 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_before_zero_grad(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_before_zero_grad 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_after_zero_grad(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_after_zero_grad 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_evaluate_begin(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_evaluate_begin 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn)

@staticmethod
@check_legality
def on_evaluate_end(every=1, once=None, filter_fn=None):
"""
当 Trainer 运行到 on_evaluate_end 时

以下三个参数互斥,只能设置其中一个。
:param int every: 触发了多少次,才真正运行一次。
:param bool once: 是否只在第一次运行后就不再执行了。
:param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。
:return:
"""
return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn)


class Filter:
def __init__(self, every: Optional[int] = 1, once: Optional[bool] = False, filter_fn: Optional[Callable] = None):
r"""
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率;

:param every: 表示一个函数隔多少次运行一次;
:param once: 表示一个函数只运行一次;
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态;
"""
# check legality
check_legality(lambda *args,**kwargs:...)(every, once, filter_fn)
# 设置变量,包括全局变量;
self.num_called = 0
self.num_executed = 0

if every is not None:
self._every = every
self._filter = self.every_filter
elif once is not None:
self._once = once
self._filter = self.once_filter
else:
self._filter = filter_fn

def __call__(self, fn: Callable):

@wraps(fn)
def wrapper(*args, **kwargs) -> Callable:
self.num_called += 1

# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer;
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)

wrapper.__fastNLP_filter__ = self
return wrapper

def every_filter(self, *args):
return self.num_called % self._every == 0

def once_filter(self, *args):
return self.num_called == self._once

def state_dict(self) -> Dict:
r"""
通过该函数来保存该 `Filter` 的状态;
"""
return {"num_called": self.num_called, "num_executed": self.num_executed}

def load_state_dict(self, state: Dict):
r"""
通过该函数来加载 `Filter` 的状态;

:param state: 通过 `Filter.state_dict` 函数保存的状态元组;
"""
self.num_called = state["num_called"]
self.num_executed = state["num_executed"]








+ 0
- 206
fastNLP/core/callbacks/callback_events.py View File

@@ -1,206 +0,0 @@
from enum import Enum, unique
from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict
from types import DynamicClassAttribute
from functools import wraps


__all__ = [
'Events',
'EventsList',
'Filter'
]


class _SingleEventState:
every: Optional[int]
once: Optional[int]

def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None,
filter_fn: Optional[Callable] = None, name: Optional[str] = None):

# 具体的检测参数对错的逻辑放在具体的 Filter 里;
if every is None and once is None and filter_fn is None:
self.every = 1
self.once = None
self.filter_fn = None
else:
self.every = every
self.once = once
self.filter_fn = filter_fn

if not hasattr(self, "_value_"):
self._value_ = value

if not hasattr(self, "_name_") and name is not None:
self._name_ = name

# copied to be compatible to enum
@DynamicClassAttribute
def name(self) -> str:
"""The name of the Enum member."""
return self._name_

@DynamicClassAttribute
def value(self) -> str:
"""The value of the Enum member."""
return self._value_

def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
return _SingleEventState(self.value, every, once, filter_fn, self.name)

def __str__(self):
return "<event={0}, every={1}, once={2}, filter fn is None:{3}>".format(self.name, self.every, self.once,
self.filter_fn)

def __eq__(self, other) -> bool:
if isinstance(other, _SingleEventState):
return self.name == other.name
elif isinstance(other, str):
return self.name == other
else:
raise NotImplemented

def __hash__(self):
return hash(self._name_)

def __or__(self, other) -> "EventsList":
return EventsList() | self | other


class EventEnum(_SingleEventState, Enum):
pass

@unique
class Events(EventEnum):
on_after_trainer_initialized = "on_after_trainer_initialized"
on_sanity_check_begin = "on_sanity_check_begin"
on_sanity_check_end = "on_sanity_check_end"
on_train_begin = "on_train_begin"
on_train_end = "on_train_end"
on_train_epoch_begin = "on_train_epoch_begin"
on_train_epoch_end = "on_train_epoch_end"
on_fetch_data_begin = "on_fetch_data_begin"
on_fetch_data_end = "on_fetch_data_end"
on_train_batch_begin = "on_train_batch_begin"
on_train_batch_end = "on_train_batch_end"
on_exception = "on_exception"
on_save_model = "on_save_model"
on_load_model = "on_load_model"
on_save_checkpoint = "on_save_checkpoint"
on_load_checkpoint = "on_load_checkpoint"
on_before_backward = "on_before_backward"
on_after_backward = "on_after_backward"
on_before_optimizers_step = "on_before_optimizers_step"
on_after_optimizers_step = "on_after_optimizers_step"
on_before_zero_grad = "on_before_zero_grad"
on_after_zero_grad = "on_after_zero_grad"
on_evaluate_begin = "on_evaluate_begin"
on_evaluate_end = "on_evaluate_end"


class EventsList:
"""Collection of events stacked by operator `__or__`.
"""

def __init__(self) -> None:
self._events = [] # type: List[Union[Events, _SingleEventState]]

def _append(self, event: Union[Events, _SingleEventState]) -> None:
if not isinstance(event, (Events, _SingleEventState)):
raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}")
self._events.append(event)

def __getitem__(self, item: int) -> Union[Events, _SingleEventState]:
return self._events[item]

def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]:
return iter(self._events)

def __len__(self) -> int:
return len(self._events)

def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList":
self._append(event=other)
return self


class Filter:
def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None):
r"""
通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率;

:param every: 表示一个函数隔多少次运行一次;
:param once: 表示一个函数只在第多少次时运行一次;
:param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和
`self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态;
"""
if (every is None) and (once is None) and (filter_fn is None):
raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.")

if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)):
raise ValueError("These three values should be only set one.")

if (filter_fn is not None) and not callable(filter_fn):
raise TypeError("Argument event_filter should be a callable")

if (every is not None) and not (isinstance(every, int) and every > 0):
raise ValueError("Argument every should be integer and greater than zero")

if (once is not None) and not (isinstance(once, int) and once > 0):
raise ValueError("Argument once should be integer and positive")

# 设置变量,包括全局变量;
self.num_called = 0
self.num_executed = 0

if every is not None:
self._every = every
self._filter = self.every_filter
elif once is not None:
self._once = once
self._filter = self.once_filter
else:
self._filter = filter_fn

def __call__(self, fn: Callable):

@wraps(fn)
def wrapper(*args, **kwargs) -> Callable:
self.num_called += 1

# 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer;
trainer = args[0]
if self._filter(self, trainer):
self.num_executed += 1
return fn(*args, **kwargs)

wrapper.__fastNLP_filter__ = self
return wrapper

def every_filter(self, *args):
return self.num_called % self._every == 0

def once_filter(self, *args):
return self.num_called == self._once

def state_dict(self) -> Dict:
r"""
通过该函数来保存该 `Filter` 的状态;
"""
return {"num_called": self.num_called, "num_executed": self.num_executed}

def load_state_dict(self, state: Dict):
r"""
通过该函数来加载 `Filter` 的状态;

:param state: 通过 `Filter.state_dict` 函数保存的状态元组;
"""
self.num_called = state["num_called"]
self.num_executed = state["num_executed"]








+ 3
- 3
fastNLP/core/callbacks/callback_manager.py View File

@@ -6,7 +6,7 @@ __all__ = [
'CallbackManager'
]

from .callback_events import Events
from .callback_event import Event
from .callback import Callback
from fastNLP.core.log import logger
from .progress_callback import ProgressCallback, choose_progress_callback
@@ -110,7 +110,7 @@ class CallbackManager:
def initialize_class_callbacks(self):
r"""
在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是
一个个 callback 时机,也就是 `Events` 的类别;
一个个 callback 时机,也就是 `Event` 的类别;
如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中;

:param callbacks:
@@ -127,7 +127,7 @@ class CallbackManager:
:param callback: 一个具体的 callback 实例;
"""
self.all_callbacks.append(callback)
for name, member in Events.__members__.items():
for name, member in Event.__members__.items():
_fn = getattr(callback, member.value)
if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)):
self.callback_fns[member.value].append(_fn)


+ 17
- 1
fastNLP/core/collators/__init__.py View File

@@ -1,4 +1,20 @@
__all__ = [
'Collator'
'Collator',

'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",
"Padder",
"NullPadder",
"RawNumberPadder",
"RawSequencePadder",
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',
"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",
"get_padded_numpy_array",
]
from .collator import Collator
from .padders import *

+ 30
- 0
fastNLP/core/collators/padders/__init__.py View File

@@ -0,0 +1,30 @@

__all__ = [
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder",

"Padder",
"NullPadder",

"RawNumberPadder",
"RawSequencePadder",

'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder',

"PaddleNumberPadder",
"PaddleTensorPadder",
"PaddleSequencePadder",

"get_padded_numpy_array",
]


from .numpy_padder import *
from .padder import Padder, NullPadder
from .raw_padder import *
from .torch_padder import *
from .paddle_padder import *
from .utils import get_padded_numpy_array

+ 6
- 25
fastNLP/core/collators/padders/get_padder.py View File

@@ -1,8 +1,3 @@

from typing import Dict



from typing import Sequence, Any, Union, Dict
from abc import ABC

@@ -93,6 +88,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")

if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种
if backend == 'raw':
@@ -103,6 +100,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")

if depth == 1 and shape_len != 0:
if backend == 'numpy':
@@ -111,6 +110,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")

if shape_len != 0 and depth>1:
msg = "Does not support pad tensor under nested list. If you need this, please report."
@@ -179,23 +180,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
else: # 包括 int/float/bool/dict 以及 其它无法pad 的等
catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别
return catalog




"""
from numbers import Number

issubclass(type(3), Number) # True
issubclass(type(3.1), Number) # True
issubclass(type('3'), Number) # False
issubclass(type(True), Number) # True
issubclass(type(np.zeros(3)[0]), Number) # True
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定
is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype)
"""




+ 4
- 1
fastNLP/core/collators/padders/raw_padder.py View File

@@ -1,4 +1,7 @@

__all__ = [
"RawNumberPadder",
"RawSequencePadder"
]

from .padder import Padder
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number


+ 5
- 1
fastNLP/core/collators/padders/torch_padder.py View File

@@ -1,4 +1,8 @@

__all__ = [
'TorchNumberPadder',
'TorchSequencePadder',
'TorchTensorPadder'
]
from inspect import isclass
import numpy as np



+ 5
- 1
fastNLP/core/collators/padders/utils.py View File

@@ -1,6 +1,10 @@

__all__ = [
'get_padded_numpy_array'
]


from typing import Sequence, List
from numbers import Number
import re
from inspect import isclass



+ 0
- 2
fastNLP/core/controllers/__init__.py View File

@@ -2,8 +2,6 @@ __all__ = [
'Loop',
'EvaluateBatchLoop',
'TrainBatchLoop',
'State',
'TrainerState',
'Evaluator',
'Trainer',
]


+ 58
- 8
fastNLP/core/controllers/trainer.py View File

@@ -17,10 +17,10 @@ from .utils import State, TrainerState
from .utils.utils import check_evaluate_every
from .evaluator import Evaluator
from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList
from fastNLP.core.callbacks import Callback, CallbackManager
from fastNLP.core.callbacks.callback import _CallbackWrapper
from fastNLP.core.callbacks.callback_manager import prepare_callbacks
from fastNLP.core.callbacks.callback_events import _SingleEventState
from fastNLP.core.callbacks.callback_event import Event
from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
@@ -398,7 +398,7 @@ class Trainer(TrainerEventTrigger):
if self.cur_epoch_idx % evaluate_every == 0:
self.run_evaluate()

def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
def add_callback_fn(self, event: Event, fn: Callable):
r"""
在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数;
这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数;
@@ -406,19 +406,69 @@ class Trainer(TrainerEventTrigger):
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机;
:param fn: 具体的 callback 函数;
"""
if not isinstance(event, (_SingleEventState, EventsList)):
raise ValueError("parameter event should only be `Events` or `EventsList` type.")
if not isinstance(event, Event):
raise ValueError("parameter event should only be `Event` type.")

_custom_callback = _CallbackWrapper(event, fn)
self.callback_manager.dissect_one_callback(_custom_callback)

@classmethod
def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None):
def on(cls, event: Event, marker: Optional[str] = None):
r"""
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制;
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如
Trainer.__init__():
on_after_trainer_initialized(trainer, driver)
Trainer.run():
if num_eval_sanity_batch>0:
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
on_sanity_check_end(trainer, sanity_check_res)
try:
on_train_begin(trainer)
while cur_epoch_idx < n_epochs:
on_train_epoch_begin(trainer)
while batch_idx_in_epoch<=num_batches_per_epoch:
on_fetch_data_begin(trainer)
batch = next(dataloader)
on_fetch_data_end(trainer)
on_train_batch_begin(trainer, batch, indices)
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。
on_after_backward(trainer)
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响
on_train_batch_end(trainer)
on_train_epoch_end(trainer)
except BaseException:
self.on_exception(trainer, exception)
finally:
on_train_end(trainer)
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中
特定的时间调用。

Example::
from fastNLP import Event
@Trainer.on(Event.on_save_model())
def do_something_1(trainer):
# do something
# 以上函数会在 Trainer 保存模型时执行。

@Trainer.on(Event.on_save_model(once=True))
def do_something_2(trainer):
# do something
# 以上函数会在 Trainer 保存模型时执行,但只执行一次。

@Trainer.on(Event.on_train_batch_begin(every=2))
def do_something_3(trainer, batch, indices):
# do something
# 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。

注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前;

:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机;
:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含
特定的参数,可以通过上述说明查阅。
:param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时,
表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer
实例使用;
@@ -426,9 +476,9 @@ class Trainer(TrainerEventTrigger):
"""

def wrapper(fn: Callable) -> Callable:
cls._custom_callbacks[marker].append((event, fn))
callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
_check_valid_parameters_number(fn, callback_fn_args)
cls._custom_callbacks[marker].append((event, fn))
return fn

return wrapper


+ 0
- 9
fastNLP/core/dataset/dataset.py View File

@@ -770,15 +770,6 @@ class DataSet:
df = self.to_pandas()
return df.to_csv(path, encoding="utf-8")

def set_ignore(self, *field_names) -> None:
"""
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉

:param field_names:
:return:
"""
self.collator.set_ignore(*field_names)

@property
def collator(self) -> Collator:
if self._collator is None:


+ 0
- 1
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -55,7 +55,6 @@ class ReproducibleBatchSampler:


class ReproduceBatchSampler(ReproducibleBatchSampler):
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
可以使得 batch_sampler 对象状态恢复的 wrapper 。


+ 2
- 0
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet

class ReproducibleSampler:
"""
可复现的 Sampler 对象。

注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler
或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。



+ 5
- 0
fastNLP/core/utils/paddle_utils.py View File

@@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]):
else:
return data.cuda(get_paddle_device_id(device))


def get_paddle_gpu_str(device: Union[str, int]):
"""
获得 `gpu:x` 类型的设备名
@@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]):
return device.replace("cuda", "gpu")
return f"gpu:{device}"


def get_paddle_device_id(device: Union[str, int]):
"""
获得 gpu 的设备id
@@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None,

return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to)


def is_in_paddle_dist():
"""
判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断
"""
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)


def is_in_fnlp_paddle_dist():
"""
判断是否处于 FastNLP 拉起的分布式进程中
"""
return FASTNLP_DISTRIBUTED_CHECK in os.environ


def is_in_paddle_launch_dist():
"""
判断是否处于 launch 启动的分布式进程中


+ 35
- 4
fastNLP/io/data_bundle.py View File

@@ -332,13 +332,44 @@ class DataBundle:
show_progress_bar=show_progress_bar, progress_desc=progress_desc)
return res

def set_pad_val(self, *field_names, val=0) -> None:
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle":
"""
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
无意义。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: self
"""
for _, ds in self.iter_datasets():
ds.set_pad_val(*field_names, val=val)
ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend,
pad_fn=pad_fn)
return self

def set_input(self, *field_names) -> None:
def set_ignore(self, *field_names) -> "DataBundle":
"""
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
Ex::
collator.set_ignore('field1', 'field2')

:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: self
"""
for _, ds in self.iter_datasets():
ds.set_input(*field_names)
ds.collator.set_ignore(*field_names)
return self

def __repr__(self) -> str:
_str = ''


tests/core/callbacks/test_callback_events.py → tests/core/callbacks/test_callback_event.py View File

@@ -1,7 +1,7 @@
import pytest
from functools import reduce

from fastNLP.core.callbacks.callback_events import Events, Filter
from fastNLP.core.callbacks.callback_event import Event, Filter




+ 2
- 2
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -216,9 +216,9 @@ def test_model_checkpoint_callback_2(
path = Path.cwd().joinpath("test_model_checkpoint")
path.mkdir(exist_ok=True, parents=True)

from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event

@Trainer.on(Events.on_train_epoch_end)
@Trainer.on(Event.on_train_epoch_end())
def raise_exception(trainer):
if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
raise NotImplementedError


+ 287
- 75
tests/core/collators/test_collator.py View File

@@ -1,81 +1,293 @@

import numpy as np
import pytest

from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.dataset import DataSet
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.new_collator import Collator


def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()


def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))


def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)


class TestCollator:

@pytest.mark.parametrize('as_numpy', [True, False])
def test_auto_collator(self, as_numpy):
"""
测试auto_collator的auto_pad功能

:param as_numpy:
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=as_numpy)
collator.set_input('x', 'y')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
results = []
for bucket in bucket_data:
res = collator(bucket)
assert res['x'].shape == (40, 5)
assert res['y'].shape == (40,)
results.append(res)

def test_auto_collator_v1(self):
"""
测试auto_collator的set_pad_val和set_pad_val功能

:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
collator.set_input('x')
collator.set_pad_val('x', val=-1)
collator.set_as_numpy(True)
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = collator(bucket)
print(res)

def test_multicollator(self):
"""
测试multicollator功能

:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
multi_collator = _MultiCollator(collator)
multi_collator.set_as_numpy(as_numpy=True)
multi_collator.set_pad_val('x', val=-1)
multi_collator.set_input('x')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = multi_collator(bucket)
print(res)
@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))

if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))

def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}

# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# 测试 set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

# 测试设置 pad 值
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# 设置 backend 和 type
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))


# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))

def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)







+ 0
- 293
tests/core/collators/test_new_collator.py View File

@@ -1,293 +0,0 @@

import numpy as np
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.new_collator import Collator


def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()


def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))


def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)


class TestCollator:

@pytest.mark.torch
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))

if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))

def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}

# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# 测试 set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

# 测试设置 pad 值
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# 设置 backend 和 type
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))


# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))

def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)







+ 35
- 35
tests/core/controllers/test_trainer_event_trigger.py View File

@@ -7,7 +7,7 @@ from torchmetrics import Accuracy
import torch.distributed as dist

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback
@@ -62,10 +62,9 @@ def model_and_optimizers():

return trainer_params

@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_1(
model_and_optimizers: TrainerParameters,
@@ -97,12 +96,13 @@ def test_trainer_event_trigger_1(
if dist.is_initialized():
dist.destroy_process_group()

for name, member in Events.__members__.items():
assert member.value in output[0]

Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]

@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7])
@magic_argv_env_context
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
@@ -111,86 +111,86 @@ def test_trainer_event_trigger_2(
n_epochs=2,
):

@Trainer.on(Events.on_after_trainer_initialized())
@Trainer.on(Event.on_after_trainer_initialized())
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")

@Trainer.on(Events.on_sanity_check_begin())
@Trainer.on(Event.on_sanity_check_begin())
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")

@Trainer.on(Events.on_sanity_check_end())
@Trainer.on(Event.on_sanity_check_end())
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")

@Trainer.on(Events.on_train_begin())
@Trainer.on(Event.on_train_begin())
def on_train_begin(trainer):
print("on_train_begin")

@Trainer.on(Events.on_train_end())
@Trainer.on(Event.on_train_end())
def on_train_end(trainer):
print("on_train_end")

@Trainer.on(Events.on_train_epoch_begin())
@Trainer.on(Event.on_train_epoch_begin())
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
# 触发 on_exception;
raise Exception
print("on_train_epoch_begin")

@Trainer.on(Events.on_train_epoch_end())
@Trainer.on(Event.on_train_epoch_end())
def on_train_epoch_end(trainer):
print("on_train_epoch_end")

@Trainer.on(Events.on_fetch_data_begin())
@Trainer.on(Event.on_fetch_data_begin())
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")

@Trainer.on(Events.on_fetch_data_end())
@Trainer.on(Event.on_fetch_data_end())
def on_fetch_data_end(trainer):
print("on_fetch_data_end")

@Trainer.on(Events.on_train_batch_begin())
@Trainer.on(Event.on_train_batch_begin())
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")

@Trainer.on(Events.on_train_batch_end())
@Trainer.on(Event.on_train_batch_end())
def on_train_batch_end(trainer):
print("on_train_batch_end")

@Trainer.on(Events.on_exception())
@Trainer.on(Event.on_exception())
def on_exception(trainer, exception):
print("on_exception")

@Trainer.on(Events.on_before_backward())
@Trainer.on(Event.on_before_backward())
def on_before_backward(trainer, outputs):
print("on_before_backward")

@Trainer.on(Events.on_after_backward())
@Trainer.on(Event.on_after_backward())
def on_after_backward(trainer):
print("on_after_backward")

@Trainer.on(Events.on_before_optimizers_step())
@Trainer.on(Event.on_before_optimizers_step())
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")

@Trainer.on(Events.on_after_optimizers_step())
@Trainer.on(Event.on_after_optimizers_step())
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")

@Trainer.on(Events.on_before_zero_grad())
@Trainer.on(Event.on_before_zero_grad())
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")

@Trainer.on(Events.on_after_zero_grad())
@Trainer.on(Event.on_after_zero_grad())
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")

@Trainer.on(Events.on_evaluate_begin())
@Trainer.on(Event.on_evaluate_begin())
def on_evaluate_begin(trainer):
print("on_evaluate_begin")

@Trainer.on(Events.on_evaluate_end())
@Trainer.on(Event.on_evaluate_end())
def on_evaluate_end(trainer, results):
print("on_evaluate_end")

@@ -215,8 +215,10 @@ def test_trainer_event_trigger_2(
if dist.is_initialized():
dist.destroy_process_group()

for name, member in Events.__members__.items():
assert member.value in output[0]
Event_attrs = Event.__dict__
for k, v in Event_attrs.items():
if isinstance(v, staticmethod):
assert k in output[0]


@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)])
@@ -235,15 +237,15 @@ def test_trainer_event_trigger_3(
once_message_3 = "once message 3"
twice_message = "twice message hei hei"

@Trainer.on(Events.on_train_epoch_begin(every=2))
@Trainer.on(Event.on_train_epoch_begin(every=2))
def train_epoch_begin_1(trainer):
print(once_message_1)

@Trainer.on(Events.on_train_epoch_begin())
@Trainer.on(Event.on_train_epoch_begin())
def train_epoch_begin_2(trainer):
print(twice_message)

@Trainer.on(Events.on_train_epoch_begin(once=2))
@Trainer.on(Event.on_train_epoch_begin(once=2))
def train_epoch_begin_3(trainer):
print(once_message_3)

@@ -253,11 +255,10 @@ def test_trainer_event_trigger_3(
else:
return False

@Trainer.on(Events.on_train_epoch_end(filter_fn=filter_fn))
@Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn))
def test_filter_fn(trainer):
print(once_message_2)


with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
@@ -308,4 +309,3 @@ def test_trainer_event_trigger_3(





+ 2
- 2
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -257,9 +257,9 @@ def test_trainer_on_exception(
cur_rank,
n_epochs=2,
):
from fastNLP.core.callbacks.callback_events import Events
from fastNLP.core.callbacks.callback_event import Event

@Trainer.on(Events.on_train_epoch_end)
@Trainer.on(Event.on_train_epoch_end())
def raise_exception(trainer):
if trainer.driver.get_local_rank() == cur_rank:
raise NotImplementedError


Loading…
Cancel
Save