diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 5cc765b9..f1421c38 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -1,4 +1,53 @@ __all__ = [ + # callbacks + 'Callback', + 'Event', + 'Filter', + 'CallbackManager', + 'CheckpointCallback', + 'choose_progress_callback', + 'ProgressCallback', + 'RichCallback', + "LRSchedCallback", + 'LoadBestModelCallback', + "EarlyStopCallback", + 'MoreEvaluateCallback', + "TorchWarmupCallback", + "TorchGradClipCallback", + + # collators + 'Collator', + 'NumpyNumberPadder', + 'NumpySequencePadder', + "NumpyTensorPadder", + "Padder", + "NullPadder", + "RawNumberPadder", + "RawSequencePadder", + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder', + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder", + "get_padded_numpy_array", + + # controllers + 'Loop', + 'EvaluateBatchLoop', + 'TrainBatchLoop', + 'Evaluator', + 'Trainer', + + # dataloaders TODO 需要把 mix_dataloader 的搞定 + + # dataset + 'DataSet', + 'FieldArray', + 'Instance', + 'ApplyResultException', + + # drivers "TorchSingleDriver", "TorchDDPDriver", "PaddleSingleDriver", @@ -7,16 +56,16 @@ __all__ = [ "JittorMPIDriver", "TorchPaddleDriver", - "paddle_to", - "get_paddle_gpu_str", - "get_paddle_device_id", - "paddle_move_data_to_device", - "torch_paddle_move_data_to_device", -] -# TODO:之后要优化一下这里的导入,应该是每一个 sub module 先import自己内部的类和函数,然后外层的 module 再直接从 submodule 中 import; -from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.controllers.evaluator import Evaluator -from fastNLP.core.dataloaders.torch_dataloader import * + # log + "logger" + # +] +from .callbacks import * +from .collators import * +from .controllers import * +from .dataloaders import * +from .dataset import * from .drivers import * +from .log import * from .utils import * \ No newline at end of file diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index bbce73e0..cfda1763 100644 --- 
a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -1,7 +1,6 @@ __all__ = [ 'Callback', - 'Events', - 'EventsList', + 'Event', 'Filter', 'CallbackManager', 'CheckpointCallback', @@ -20,7 +19,7 @@ __all__ = [ from .callback import Callback -from .callback_events import EventsList, Events, Filter +from .callback_event import Event, Filter from .callback_manager import CallbackManager from .checkpoint_callback import CheckpointCallback from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 7f0c290d..6b5c74fc 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -3,10 +3,9 @@ __all__ = [ 'Callback', ] -from typing import Union, Callable, Dict, Optional, Any +from typing import Callable, Dict, Optional -from .callback_events import Events, EventsList, Filter -from fastNLP.core.callbacks.callback_events import _SingleEventState +from .callback_event import Event, Filter class Callback: @@ -14,32 +13,35 @@ class Callback: 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; callback 调用时机顺序大概如下 Trainer.__init__(): - on_after_trainer_initialized() + on_after_trainer_initialized(trainer, driver) Trainer.run(): if num_eval_sanity_batch>0: - on_sanity_check_begin() # 如果设置了num_eval_sanity_batch - on_sanity_check_end() + on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch + on_sanity_check_end(trainer, sanity_check_res) try: - on_train_begin() + on_train_begin(trainer) while cur_epoch_idx < n_epochs: - on_train_epoch_begin() + on_train_epoch_begin(trainer) while batch_idx_in_epoch<=num_batches_per_epoch: - on_fetch_data_begin() - on_fetch_data_end() - on_train_batch_begin() - on_before_backward() - on_after_backward() - on_before_zero_grad() # 实际调用受到 accumulation_steps 影响 - on_after_zero_grad() # 实际调用受到 accumulation_steps 影响 - on_before_optimizers_step() 
# 实际调用受到 accumulation_steps 影响 - on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响 - on_train_batch_end() - on_train_epoch_end() + on_fetch_data_begin(trainer) + batch = next(dataloader) + on_fetch_data_end(trainer) + on_train_batch_begin(trainer, batch, indices) + on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 + on_after_backward(trainer) + on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_train_batch_end(trainer) + on_train_epoch_end(trainer) except BaseException: - self.on_exception() + self.on_exception(trainer, exception) finally: - on_train_end() - 其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将 + on_train_end(trainer) + 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ + on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中特定 + 的时间调用。 """ def on_after_trainer_initialized(self, trainer, driver): @@ -294,18 +296,14 @@ class _CallbackWrapper(Callback): 对于用户使用函数修饰器加入的 callback 函数,使用该 _CallbackWrapper 类为其进行定制,这一个类只保留用户的 这一个 callback 函数; """ - def __init__(self, event: Union[Events, EventsList], fn: Callable): + def __init__(self, event: Event, fn: Callable): r""" - :param event: 具体的 callback 时机,例如 'on_train_begin' 等;可以多个时机,此时 `event` 的 type 应当为 'EventsList'; + :param event: 具体的 callback 时机,例如 'on_train_begin' 等; :param fn: 用户定制的 callback 函数; """ self.fn = fn - if isinstance(event, EventsList): - for each_event in event: - _filter = Filter(each_event.every, each_event.once, each_event.filter_fn) - setattr(self, each_event.value, _filter(fn)) - elif isinstance(event, _SingleEventState): + if isinstance(event, Event): _filter = Filter(event.every, 
event.once, event.filter_fn) setattr(self, event.value, _filter(fn)) diff --git a/fastNLP/core/callbacks/callback_event.py b/fastNLP/core/callbacks/callback_event.py new file mode 100644 index 00000000..b3088a66 --- /dev/null +++ b/fastNLP/core/callbacks/callback_event.py @@ -0,0 +1,499 @@ +from typing import Optional, Callable, Dict +from functools import wraps + + +__all__ = [ + 'Event', + 'Filter' +] + + +def check_legality(fn): + @wraps(fn) + def wrap(every=None, once=None, filter_fn=None): + if (every is None) and (once is None) and (filter_fn is None): + every = 1 + + if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): + raise ValueError("These three values should be only set one.") + + if (filter_fn is not None) and not callable(filter_fn): + raise TypeError("Argument filter_fn should be a callable") + + if (every is not None) and not (isinstance(every, int) and every > 0): + raise ValueError("Argument every should be integer and greater than zero") + + if (once is not None) and not (isinstance(once, int) and once > 0): + raise ValueError("Argument once should be integer and positive") + return fn(every=every, once=once, filter_fn=filter_fn) + return wrap + + +class Event: + every: Optional[int] + once: Optional[int] + + def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, + filter_fn: Optional[Callable] = None): + """ + 请勿直接使用本对象,而是通过调用 Event.on_after_trainer_initialized() 等方式调用。 + + :param value: Trainer 的 callback 时机。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + """ + self.every = every + self.once = once + self.filter_fn = filter_fn + self.value = value + + def __str__(self): + return "<event={}, every={}, once={}, filter_fn={}>".format(self.value, self.every, self.once, + self.filter_fn) + @staticmethod + @check_legality + def 
on_after_trainer_initialized(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_trainer_initialized 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。默认为 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_sanity_check_begin(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_sanity_check_begin 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_sanity_check_end(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_sanity_check_end 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_begin(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_begin 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 
两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_end(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_end 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_epoch_begin(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_epoch_begin 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_epoch_end(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_epoch_end 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_fetch_data_begin(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_fetch_data_begin 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 
是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_fetch_data_end(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_fetch_data_end 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_batch_begin(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_batch_begin 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_train_batch_end(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_train_batch_end 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_exception(every=None, once=None, 
filter_fn=None): + """ + 当 Trainer 运行到 on_exception 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_save_model(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_save_model 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_load_model(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_load_model 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_save_checkpoint(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_save_checkpoint 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_save_checkpoint', every=every, once=once, 
filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_load_checkpoint(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_load_checkpoint 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_before_backward(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_before_backward 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_after_backward(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_backward 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + 
filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_before_optimizers_step(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_before_optimizers_step 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_after_optimizers_step(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_optimizers_step 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_before_zero_grad(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_before_zero_grad 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_after_zero_grad(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_after_zero_grad 时 + + 
以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_evaluate_begin(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_evaluate_begin 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) + + @staticmethod + @check_legality + def on_evaluate_end(every=None, once=None, filter_fn=None): + """ + 当 Trainer 运行到 on_evaluate_end 时 + + 以下三个参数互斥,只能设置其中一个。默认为行为等同于 every=1 。 + :param int every: 触发了多少次,才真正运行一次。 + :param bool once: 是否只在第一次运行后就不再执行了。 + :param Callable filter_fn: 输入参数的应该为 (filter, trainer),其中 filter 对象中包含了 filter.num_called 和 + filter.num_executed 两个变了分别获取当前被调用了多少次,真正执行了多少次。trainer 对象即为当前正在运行的 Trainer 。 + :return: + """ + return Event(value='on_evaluate_end', every=every, once=once, filter_fn=filter_fn) + + +class Filter: + def __init__(self, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): + r""" + 通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; + + :param every: 表示一个函数隔多少次运行一次; + :param once: 表示一个函数只运行一次; + :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 + `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; + """ + # check legality + check_legality(lambda *args,**kwargs:...)(every, once, filter_fn) + if (every is None) and (once is None) and (filter_fn is None): + every = 1 + # 
设置变量,包括全局变量; + self.num_called = 0 + self.num_executed = 0 + + if every is not None: + self._every = every + self._filter = self.every_filter + elif once is not None: + self._once = once + self._filter = self.once_filter + else: + self._filter = filter_fn + + def __call__(self, fn: Callable): + + @wraps(fn) + def wrapper(*args, **kwargs) -> Callable: + self.num_called += 1 + + # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; + trainer = args[0] + if self._filter(self, trainer): + self.num_executed += 1 + return fn(*args, **kwargs) + + wrapper.__fastNLP_filter__ = self + return wrapper + + def every_filter(self, *args): + return self.num_called % self._every == 0 + + def once_filter(self, *args): + return self.num_called == self._once + + def state_dict(self) -> Dict: + r""" + 通过该函数来保存该 `Filter` 的状态; + """ + return {"num_called": self.num_called, "num_executed": self.num_executed} + + def load_state_dict(self, state: Dict): + r""" + 通过该函数来加载 `Filter` 的状态; + + :param state: 通过 `Filter.state_dict` 函数保存的状态元组; + """ + self.num_called = state["num_called"] + self.num_executed = state["num_executed"] + + + + + + + diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py deleted file mode 100644 index 7252398c..00000000 --- a/fastNLP/core/callbacks/callback_events.py +++ /dev/null @@ -1,206 +0,0 @@ -from enum import Enum, unique -from typing import Union, Optional, List, Iterator, Callable, Tuple, Dict -from types import DynamicClassAttribute -from functools import wraps - - -__all__ = [ - 'Events', - 'EventsList', - 'Filter' -] - - -class _SingleEventState: - every: Optional[int] - once: Optional[int] - - def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, - filter_fn: Optional[Callable] = None, name: Optional[str] = None): - - # 具体的检测参数对错的逻辑放在具体的 Filter 里; - if every is None and once is None and filter_fn is None: - self.every = 1 - self.once = None - self.filter_fn = None - else: - 
self.every = every - self.once = once - self.filter_fn = filter_fn - - if not hasattr(self, "_value_"): - self._value_ = value - - if not hasattr(self, "_name_") and name is not None: - self._name_ = name - - # copied to be compatible to enum - @DynamicClassAttribute - def name(self) -> str: - """The name of the Enum member.""" - return self._name_ - - @DynamicClassAttribute - def value(self) -> str: - """The value of the Enum member.""" - return self._value_ - - def __call__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): - return _SingleEventState(self.value, every, once, filter_fn, self.name) - - def __str__(self): - return "".format(self.name, self.every, self.once, - self.filter_fn) - - def __eq__(self, other) -> bool: - if isinstance(other, _SingleEventState): - return self.name == other.name - elif isinstance(other, str): - return self.name == other - else: - raise NotImplemented - - def __hash__(self): - return hash(self._name_) - - def __or__(self, other) -> "EventsList": - return EventsList() | self | other - - -class EventEnum(_SingleEventState, Enum): - pass - -@unique -class Events(EventEnum): - on_after_trainer_initialized = "on_after_trainer_initialized" - on_sanity_check_begin = "on_sanity_check_begin" - on_sanity_check_end = "on_sanity_check_end" - on_train_begin = "on_train_begin" - on_train_end = "on_train_end" - on_train_epoch_begin = "on_train_epoch_begin" - on_train_epoch_end = "on_train_epoch_end" - on_fetch_data_begin = "on_fetch_data_begin" - on_fetch_data_end = "on_fetch_data_end" - on_train_batch_begin = "on_train_batch_begin" - on_train_batch_end = "on_train_batch_end" - on_exception = "on_exception" - on_save_model = "on_save_model" - on_load_model = "on_load_model" - on_save_checkpoint = "on_save_checkpoint" - on_load_checkpoint = "on_load_checkpoint" - on_before_backward = "on_before_backward" - on_after_backward = "on_after_backward" - on_before_optimizers_step = 
"on_before_optimizers_step" - on_after_optimizers_step = "on_after_optimizers_step" - on_before_zero_grad = "on_before_zero_grad" - on_after_zero_grad = "on_after_zero_grad" - on_evaluate_begin = "on_evaluate_begin" - on_evaluate_end = "on_evaluate_end" - - -class EventsList: - """Collection of events stacked by operator `__or__`. - """ - - def __init__(self) -> None: - self._events = [] # type: List[Union[Events, _SingleEventState]] - - def _append(self, event: Union[Events, _SingleEventState]) -> None: - if not isinstance(event, (Events, _SingleEventState)): - raise TypeError(f"Argument event should be Events or CallableEventWithFilter, got: {type(event)}") - self._events.append(event) - - def __getitem__(self, item: int) -> Union[Events, _SingleEventState]: - return self._events[item] - - def __iter__(self) -> Iterator[Union[Events, _SingleEventState]]: - return iter(self._events) - - def __len__(self) -> int: - return len(self._events) - - def __or__(self, other: Union[Events, _SingleEventState]) -> "EventsList": - self._append(event=other) - return self - - -class Filter: - def __init__(self, every: Optional[int] = None, once: Optional[int] = None, filter_fn: Optional[Callable] = None): - r""" - 通过该 `Filter` 作为函数修饰器来控制一个函数的实际的运行频率; - - :param every: 表示一个函数隔多少次运行一次; - :param once: 表示一个函数只在第多少次时运行一次; - :param filter_fn: 用户定制的频率控制函数;注意该函数内部的频率判断应当是无状态的,除了参数 `self.num_called` 和 - `self.num_executed` 外,因为我们会在预跑后重置这两个参数的状态; - """ - if (every is None) and (once is None) and (filter_fn is None): - raise ValueError("If you mean your decorated function should be called every time, you do not need this filter.") - - if not ((every is not None) ^ (once is not None) ^ (filter_fn is not None)): - raise ValueError("These three values should be only set one.") - - if (filter_fn is not None) and not callable(filter_fn): - raise TypeError("Argument event_filter should be a callable") - - if (every is not None) and not (isinstance(every, int) and every > 0): - raise 
ValueError("Argument every should be integer and greater than zero") - - if (once is not None) and not (isinstance(once, int) and once > 0): - raise ValueError("Argument once should be integer and positive") - - # 设置变量,包括全局变量; - self.num_called = 0 - self.num_executed = 0 - - if every is not None: - self._every = every - self._filter = self.every_filter - elif once is not None: - self._once = once - self._filter = self.once_filter - else: - self._filter = filter_fn - - def __call__(self, fn: Callable): - - @wraps(fn) - def wrapper(*args, **kwargs) -> Callable: - self.num_called += 1 - - # 因为我们的 callback 函数的输入是固定的,而且我们能够保证第一个参数一定是 trainer; - trainer = args[0] - if self._filter(self, trainer): - self.num_executed += 1 - return fn(*args, **kwargs) - - wrapper.__fastNLP_filter__ = self - return wrapper - - def every_filter(self, *args): - return self.num_called % self._every == 0 - - def once_filter(self, *args): - return self.num_called == self._once - - def state_dict(self) -> Dict: - r""" - 通过该函数来保存该 `Filter` 的状态; - """ - return {"num_called": self.num_called, "num_executed": self.num_executed} - - def load_state_dict(self, state: Dict): - r""" - 通过该函数来加载 `Filter` 的状态; - - :param state: 通过 `Filter.state_dict` 函数保存的状态元组; - """ - self.num_called = state["num_called"] - self.num_executed = state["num_executed"] - - - - - - - diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 2b8fff60..7b04d8ad 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -6,7 +6,7 @@ __all__ = [ 'CallbackManager' ] -from .callback_events import Events +from .callback_event import Event from .callback import Callback from fastNLP.core.log import logger from .progress_callback import ProgressCallback, choose_progress_callback @@ -110,7 +110,7 @@ class CallbackManager: def initialize_class_callbacks(self): r""" 在实际的运行过程中,我们是将具体的一个 callback 实例拆分为单独的一个个 callback 函数,然后将它们加在一个字典里,该字典的键值就是 - 
一个个 callback 时机,也就是 `Events` 的类别; + 一个个 callback 时机,也就是 `Event` 的类别; 如果一个 callback 类的 callback 函数并不具备任何作用,我们实际并不会将其加在字典当中; :param callbacks: @@ -127,11 +127,12 @@ class CallbackManager: :param callback: 一个具体的 callback 实例; """ self.all_callbacks.append(callback) - for name, member in Events.__members__.items(): - _fn = getattr(callback, member.value) - if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, member.value)): - self.callback_fns[member.value].append(_fn) - self.extract_callback_filter_state(callback.callback_name, _fn) + for name, member in Event.__dict__.items(): + if isinstance(member, staticmethod): + _fn = getattr(callback, name) + if inspect.getsource(_fn) != inspect.getsource(getattr(Callback, name)): + self.callback_fns[name].append(_fn) + self.extract_callback_filter_state(callback.callback_name, _fn) def extract_callback_filter_state(self, callback_name, callback_fn): r""" diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index 52214ff0..c5c5edde 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -161,7 +161,6 @@ class MonitorUtility: return monitor_name - class HasMonitorCallback(MonitorUtility, Callback): def __init__(self, monitor, larger_better, must_have_monitor=False): """ diff --git a/fastNLP/core/collators/__init__.py b/fastNLP/core/collators/__init__.py index 17cbb6ae..1e508689 100644 --- a/fastNLP/core/collators/__init__.py +++ b/fastNLP/core/collators/__init__.py @@ -1,4 +1,20 @@ __all__ = [ - 'Collator' + 'Collator', + + 'NumpyNumberPadder', + 'NumpySequencePadder', + "NumpyTensorPadder", + "Padder", + "NullPadder", + "RawNumberPadder", + "RawSequencePadder", + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder', + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder", + "get_padded_numpy_array", ] from .collator import Collator +from .padders import * diff --git 
a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index ceb50a29..5c5abda4 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -65,12 +65,16 @@ def _get_backend() -> str: return catch_backend[0] # 方式 (2) + for backend in CHECK_BACKEND: + if backend in sys.modules: + logger.debug(f"sys.modules contains backend:{backend}.") + return backend for key, module in sys.modules.items(): catch_backend = _check_module(module) if catch_backend: break if len(catch_backend): - logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") + logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.") return catch_backend[0] return 'numpy' @@ -227,7 +231,7 @@ class Collator: 设置可以 pad 的 field 默认 pad 为什么类型的 tensor :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], - 若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。 + 若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。 :return: """ assert backend in SUPPORTED_BACKENDS diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py index e69de29b..09a5ca8d 100644 --- a/fastNLP/core/collators/padders/__init__.py +++ b/fastNLP/core/collators/padders/__init__.py @@ -0,0 +1,30 @@ + +__all__ = [ + 'NumpyNumberPadder', + 'NumpySequencePadder', + "NumpyTensorPadder", + + "Padder", + "NullPadder", + + "RawNumberPadder", + "RawSequencePadder", + + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder', + + "PaddleNumberPadder", + "PaddleTensorPadder", + "PaddleSequencePadder", + + "get_padded_numpy_array", +] + + +from .numpy_padder import * +from .padder import Padder, NullPadder +from .raw_padder import * +from .torch_padder import * +from .paddle_padder import * +from .utils import get_padded_numpy_array \ No newline at end of file diff --git 
a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 3e136d7d..1df0c0d8 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -1,8 +1,3 @@ - -from typing import Dict - - - from typing import Sequence, Any, Union, Dict from abc import ABC @@ -12,7 +7,7 @@ from fastNLP.core.log import logger from .padder import Padder, NullPadder from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder -from .raw_padder import RawNumberPadder, RawSequencePadder +from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from .exceptions import * @@ -28,7 +23,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> :param field_name: 方便报错的。 :return: """ - + assert len(batch_field)!=0, "Empty batch encountered." logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field)) if pad_val is None: logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") @@ -68,7 +63,10 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return NullPadder() # 再检查所有的元素 type 是否一致 - ele_dtypes = set([v[1] for v in catalog.values()]) + try: + ele_dtypes = set([v[1] for v in catalog.values()]) + except TypeError: + ele_dtypes = set([str(v[1]) for v in catalog.values()]) num_eletypes = len(ele_dtypes) if num_eletypes != 1: msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. 
To view more ' \ @@ -80,7 +78,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> depth = depths.pop() shape_len = shape_lens.pop() - ele_dtype = ele_dtypes.pop() + ele_dtype = list(catalog.values())[0][1] # 因为上面有except的情况,所以这样处理了 # 需要由 padder 自己决定是否能够 pad 。 try: @@ -93,6 +91,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'paddle': return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + else: + raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 if backend == 'raw': @@ -103,14 +103,21 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'paddle': return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + else: + raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") - if depth == 1 and shape_len != 0: - if backend == 'numpy': - return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + # 如果有有 shape 的话,只有当该对象拥有 tolist() 方法才行 + if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)): + if backend == 'raw': + return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) + elif backend == 'numpy': + return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'torch': - return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) elif backend == 'paddle': - return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) + return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) + else: + raise 
ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") if shape_len != 0 and depth>1: msg = "Does not support pad tensor under nested list. If you need this, please report." @@ -179,23 +186,3 @@ def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 return catalog - - - - -""" -from numbers import Number - -issubclass(type(3), Number) # True -issubclass(type(3.1), Number) # True -issubclass(type('3'), Number) # False -issubclass(type(True), Number) # True -issubclass(type(np.zeros(3)[0]), Number) # True -isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True -isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True -isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 -is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) -""" - - - diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py index 40b70683..4d507f2e 100644 --- a/fastNLP/core/collators/padders/numpy_padder.py +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -66,7 +66,7 @@ class NumpySequencePadder(Padder): class NumpyTensorPadder(Padder): def __init__(self, pad_val=0, ele_dtype=None, dtype=None): """ - pad 类似于 [np.array([3, 4], np.array([1])] 的 field + pad 类似于 [np.array([3, 4], np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 :param pad_val: pad 的值是多少。 :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 @@ -77,6 +77,13 @@ class NumpyTensorPadder(Padder): @staticmethod def pad(batch_field, pad_val, dtype): + try: + if not isinstance(batch_field[0], np.ndarray): + batch_field = [np.array(field.tolist()) for field in batch_field] + except AttributeError: + raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " + f"it must have tolist() method.") + shapes = [field.shape for field in batch_field] 
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] array = np.full(max_shape, fill_value=pad_val, dtype=dtype) diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 7a569003..10d5a385 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -56,7 +56,7 @@ def is_paddle_dtype_str(dtype): def _get_dtype(ele_dtype, dtype, class_name): - if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): + if not (ele_dtype is not None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.") @@ -74,13 +74,20 @@ def _get_dtype(ele_dtype, dtype, class_name): elif is_numpy_generic_class(ele_dtype): dtype = numpy_to_paddle_dtype_dict.get(ele_dtype) else: - dtype == ele_dtype + dtype = ele_dtype return dtype class PaddleNumberPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) + + :param pad_val: 该值无意义 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 + :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 + """ # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -91,7 +98,14 @@ class PaddleNumberPadder(Padder): class PaddleSequencePadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, ele_dtype=None, pad_val=0, dtype=None): + """ + 将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + + :param pad_val: pad 的值。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 
paddle.tensor 类型。 + :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 + """ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -102,19 +116,26 @@ class PaddleSequencePadder(Padder): class PaddleTensorPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): """ - 目前仅支持 [paddle.tensor([3, 2], paddle.tensor([1])] 类似的 + 目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 - :param ele_dtype: - :param pad_val: - :param dtype: + :param pad_val: pad 的值。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 + :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 """ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @staticmethod def pad(batch_field, pad_val, dtype): + try: + if not isinstance(batch_field[0], paddle.Tensor): + batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field] + except AttributeError: + raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " + f"it must have tolist() method.") + shapes = [field.shape for field in batch_field] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] if isinstance(dtype, np.dtype): @@ -174,6 +195,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): """ shapes = get_shape(batch_field) tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) - # tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val) tensor = fill_tensor(batch_field, tensor, dtype=dtype) return tensor diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py index fe0c0b14..bc3bbaee 100644 --- a/fastNLP/core/collators/padders/raw_padder.py +++ b/fastNLP/core/collators/padders/raw_padder.py @@ -1,4 +1,8 @@ - +__all__ = [ + 
"RawNumberPadder", + "RawSequencePadder", + "RawTensorPadder" +] from .padder import Padder from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number @@ -63,3 +67,34 @@ class RawSequencePadder(Padder): :return: """ return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() + + +class RawTensorPadder(Padder): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 + + :param pad_val: pad 的值 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param dtype: 输出的数据的 dtype 是什么 + """ + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + """ + + :param batch_field: + :param pad_val: + :param dtype: 该参数无意义。 + :return: + """ + try: + if not isinstance(batch_field[0], (list, tuple)): + batch_field = [field.tolist() for field in batch_field] + except AttributeError: + raise RuntimeError(f"If the field is not a list or tuple(it is {type(batch_field[0])}), " + f"it must have tolist() method.") + + return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index 7c0eaa33..18f414e8 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -1,4 +1,8 @@ - +__all__ = [ + 'TorchNumberPadder', + 'TorchSequencePadder', + 'TorchTensorPadder' +] from inspect import isclass import numpy as np @@ -37,7 +41,7 @@ def is_torch_tensor(dtype): def _get_dtype(ele_dtype, dtype, class_name): - if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): + if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding 
python numbers " f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") @@ -97,7 +101,7 @@ class TorchSequencePadder(Padder): class TorchTensorPadder(Padder): def __init__(self, pad_val=0, ele_dtype=None, dtype=None): """ - 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 + 目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 :param pad_val: 需要 pad 的值。 :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 @@ -108,6 +112,13 @@ class TorchTensorPadder(Padder): @staticmethod def pad(batch_field, pad_val, dtype): + try: + if not isinstance(batch_field[0], torch.Tensor): + batch_field = [torch.tensor(field.tolist()) for field in batch_field] + except AttributeError: + raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " + f"it must have tolist() method.") + shapes = [field.shape for field in batch_field] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py index d2d3a8e0..b322897f 100644 --- a/fastNLP/core/collators/padders/utils.py +++ b/fastNLP/core/collators/padders/utils.py @@ -1,6 +1,10 @@ +__all__ = [ + 'get_padded_numpy_array' +] + + from typing import Sequence, List -from numbers import Number import re from inspect import isclass diff --git a/fastNLP/core/controllers/__init__.py b/fastNLP/core/controllers/__init__.py index 3e93343d..ec47f254 100644 --- a/fastNLP/core/controllers/__init__.py +++ b/fastNLP/core/controllers/__init__.py @@ -2,8 +2,6 @@ __all__ = [ 'Loop', 'EvaluateBatchLoop', 'TrainBatchLoop', - 'State', - 'TrainerState', 'Evaluator', 'Trainer', ] diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 5223c9d8..e0cf4b0d 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -17,10 +17,10 @@ from .utils import State, 
TrainerState from .utils.utils import check_evaluate_every from .evaluator import Evaluator from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader -from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList +from fastNLP.core.callbacks import Callback, CallbackManager from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_manager import prepare_callbacks -from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.callbacks.callback_event import Event from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext @@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger): raise e finally: self.on_train_end() - self.driver.barrier() def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: @@ -399,7 +398,7 @@ class Trainer(TrainerEventTrigger): if self.cur_epoch_idx % evaluate_every == 0: self.run_evaluate() - def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): + def add_callback_fn(self, event: Event, fn: Callable): r""" 在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; 这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; @@ -407,19 +406,69 @@ class Trainer(TrainerEventTrigger): :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; :param fn: 具体的 callback 函数; """ - if not isinstance(event, (_SingleEventState, EventsList)): - raise ValueError("parameter event should only be `Events` or `EventsList` type.") + if not isinstance(event, Event): + raise ValueError("parameter event should only be `Event` type.") _custom_callback = _CallbackWrapper(event, fn) self.callback_manager.dissect_one_callback(_custom_callback) @classmethod - def on(cls, event: Optional[Union[Events, EventsList]], marker: Optional[str] = None): + def 
on(cls, event: Event, marker: Optional[str] = None): r""" 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; + 支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如 + Trainer.__init__(): + on_after_trainer_initialized(trainer, driver) + Trainer.run(): + if num_eval_sanity_batch>0: + on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch + on_sanity_check_end(trainer, sanity_check_res) + try: + on_train_begin(trainer) + while cur_epoch_idx < n_epochs: + on_train_epoch_begin(trainer) + while batch_idx_in_epoch<=num_batches_per_epoch: + on_fetch_data_begin(trainer) + batch = next(dataloader) + on_fetch_data_end(trainer) + on_train_batch_begin(trainer, batch, indices) + on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 + on_after_backward(trainer) + on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 + on_train_batch_end(trainer) + on_train_epoch_end(trainer) + except BaseException: + self.on_exception(trainer, exception) + finally: + on_train_end(trainer) + 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ + on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 + 特定的时间调用。 + + Example:: + from fastNLP import Event + @Trainer.on(Event.on_save_model()) + def do_something_1(trainer): + # do something + # 以上函数会在 Trainer 保存模型时执行。 + + @Trainer.on(Event.on_save_model(once=True)) + def do_something_2(trainer): + # do something + # 以上函数会在 Trainer 保存模型时执行,但只执行一次。 + + @Trainer.on(Event.on_train_batch_begin(every=2)) + def do_something_3(trainer, batch, indices): + # do something + # 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。 + 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 
callback 函数的代码在实例化 `Trainer` 之前; - :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; + :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含 + 特定的参数,可以通过上述说明查阅。 :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer 实例使用; @@ -427,9 +476,9 @@ class Trainer(TrainerEventTrigger): """ def wrapper(fn: Callable) -> Callable: - cls._custom_callbacks[marker].append((event, fn)) callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:] _check_valid_parameters_number(fn, callback_fn_args) + cls._custom_callbacks[marker].append((event, fn)) return fn return wrapper @@ -441,6 +490,7 @@ class Trainer(TrainerEventTrigger): """ _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"]) _own_callbacks.extend(self._custom_callbacks[None]) + logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().") self._custom_callbacks[None] = [] if self.marker is not None: if len(self._custom_callbacks[self.marker]) == 0: diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 787fcb69..507073a4 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -14,7 +14,7 @@ else: from fastNLP.core.dataset import DataSet as Dataset from fastNLP.core.utils.jittor_utils import jittor_collate_wraps from fastNLP.core.collators import Collator -from fastNLP.core.utils.utils import indice_collate_wrapper +from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.core.dataset import DataSet as FDataSet @@ -107,33 +107,33 @@ class JittorDataLoader: return len(self.dataset) // self.dataset.batch_size return (len(self.dataset) - 1) // self.dataset.batch_size + 1 - def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, 
dtype=None, backend=None, - pad_fn: Callable = None) -> "JittorDataLoader": + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> Collator: """ - 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, - paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 
的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ if isinstance(self._collate_fn, Collator): - self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, - backend=backend) - return self + self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") - def set_ignore(self, *field_names) -> "JittorDataLoader": + def set_ignore(self, *field_names) -> Collator: """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Ex:: @@ -146,18 +146,17 @@ class JittorDataLoader: """ if isinstance(self._collate_fn, Collator): self._collate_fn.set_ignore(*field_names) - return self + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") def get_batch_indices(self) -> List[int]: """ - 获取当前数据的idx + 获取当前 batch 的 idx :return: """ return self.cur_batch_indices - def prepare_jittor_dataloader(): ... 
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index b4b675c4..fa99be22 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -15,8 +15,9 @@ else: from fastNLP.core.utils.dummy_class import DummyClass as DataLoader from fastNLP.core.collators.collator import Collator -from fastNLP.core.utils.utils import indice_collate_wrapper +from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.core.dataset import DataSet as FDataSet +from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler class _PaddleDataset(Dataset): @@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader): if not isinstance(dataset, _PaddleDataset): dataset = _PaddleDataset(dataset) + if batch_sampler is None: + batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle, + drop_last=drop_last) + super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, @@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader): if isinstance(dataset.dataset, FDataSet): self._collate_fn = dataset.dataset.collator self._collate_fn.set_backend(backend="paddle") - # if collate_fn is not None: - # self._collate_fn.add_collator(collate_fn) else: self._collate_fn = Collator(backend="paddle") @@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader): self.cur_batch_indices = indices yield data - def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, - pad_fn: Callable = None) -> "PaddleDataLoader": + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> Collator: """ - 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 
__getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, - paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ if isinstance(self._collate_fn, Collator): - self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, 
dtype=dtype, pad_fn=pad_fn, - backend=backend) - return self + self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") - def set_ignore(self, *field_names) -> "PaddleDataLoader": + def set_ignore(self, *field_names) -> Collator: """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Ex:: @@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader): """ if isinstance(self._collate_fn, Collator): self._collate_fn.set_ignore(*field_names) - return self + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") def get_batch_indices(self) -> List[int]: """ - 获取当前数据的idx + 获取当前 batch 的 idx :return: """ @@ -147,7 +150,8 @@ class PaddleDataLoader(DataLoader): def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, - return_list: bool = True, batch_sampler=None, + return_list: bool = True, + batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, train_batch_size: int = 1, shuffle: bool = False, drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, num_workers: int = 0, use_buffer_reader: bool = True, diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 3ee838c4..d008d4ad 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -3,14 +3,14 @@ __all__ = [ 'prepare_torch_dataloader' ] -from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping +from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List from fastNLP.core.dataset import DataSet from fastNLP.core.collators import Collator -from 
fastNLP.core.utils.utils import indice_collate_wrapper +from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.io.data_bundle import DataBundle from fastNLP.envs.imports import _NEED_IMPORT_TORCH -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler if _NEED_IMPORT_TORCH: from torch.utils.data import DataLoader, Sampler @@ -76,6 +76,10 @@ class TorchDataLoader(DataLoader): if not isinstance(dataset, _FDataSet): dataset = _FDataSet(dataset) + if sampler is None and batch_sampler is None: + sampler = RandomSampler(dataset, shuffle=shuffle) + shuffle=False + super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -87,9 +91,6 @@ class TorchDataLoader(DataLoader): if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset self._collate_fn = dataset.dataset.collator self._collate_fn.set_backend(backend="torch") - # if collate_fn is not None and collate_fn is not default_collate: - # # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 - # self._collate_fn.add_collator(collate_fn) else: self._collate_fn = Collator(backend="torch") else: @@ -112,31 +113,32 @@ class TorchDataLoader(DataLoader): yield data def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, - pad_fn:Callable=None) -> "TorchDataLoader": + pad_fn:Callable=None) -> Collator: """ - 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 
表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, - paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 """ if isinstance(self._collate_fn, Collator): self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) - return self + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, 
set_pad() is allowed.") - def set_ignore(self, *field_names) -> "TorchDataLoader": + def set_ignore(self, *field_names) -> Collator: """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 Ex:: @@ -149,24 +151,23 @@ class TorchDataLoader(DataLoader): """ if isinstance(self._collate_fn, Collator): self._collate_fn.set_ignore(*field_names) - return self + return self._collate_fn else: - raise ValueError(f"collate_fn is not fastnlp collator") + raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") def get_batch_indices(self) -> List[int]: """ - 获取当前数据的idx + 获取当前 batch 的 idx :return: """ return self.cur_batch_indices - def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1, - shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, - batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable] = None, diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py new file mode 100644 index 00000000..a71dc50c --- /dev/null +++ b/fastNLP/core/dataloaders/utils.py @@ -0,0 +1,16 @@ +def indice_collate_wrapper(func): + """ + 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 + + :param func: 需要修饰的函数 + :return: + """ + + def wrapper(tuple_data): + indice, ins_list = [], [] + for idx, ins in tuple_data: + indice.append(idx) + ins_list.append(ins) + return indice, func(ins_list) + + return wrapper \ No newline at end of file diff --git a/fastNLP/core/dataloaders/utils/__init__.py b/fastNLP/core/dataloaders/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff 
--git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 0c79bc92..3b9f027e 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -770,17 +770,8 @@ class DataSet: df = self.to_pandas() return df.to_csv(path, encoding="utf-8") - def set_ignore(self, *field_names) -> None: - """ - 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 - - :param field_names: - :return: - """ - self.collator.set_ignore(*field_names) - @property - def collator(self): + def collator(self) -> Collator: if self._collator is None: self._collator = Collator() return self._collator diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index a1275bed..73342748 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -22,7 +22,7 @@ from fastNLP.core.utils import ( rank_zero_rm ) from fastNLP.core.samplers import ( - RandomBatchSampler, + ReproduceBatchSampler, ReproducibleSampler, ReproducibleBatchSampler, RandomSampler, @@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver): return self.model, model.forward - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]], reproducible: bool = False): r""" 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index ed1aad73..f65efd3d 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -22,7 +22,7 @@ from fastNLP.core.log import logger from fastNLP.core.samplers import ( ReproducibleBatchSampler, ReproducibleSampler, - RandomBatchSampler, + ReproduceBatchSampler, RandomSampler, ) @@ -345,7 +345,7 @@ class 
PaddleDriver(Driver): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " "`ReproducibleSampler`.") else: - sampler = RandomBatchSampler( + sampler = ReproduceBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last @@ -476,7 +476,7 @@ class PaddleDriver(Driver): res.shuffle = True else: res.shuffle = False - # RandomBatchSampler 的情况 + # ReproduceBatchSampler 的情况 elif hasattr(dataloader.batch_sampler, "batch_sampler"): batch_sampler = dataloader.batch_sampler.batch_sampler res.sampler = batch_sampler.sampler diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index f140ad69..52805a97 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -14,7 +14,7 @@ from fastNLP.core.utils import ( from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.samplers import ( ReproducibleBatchSampler, - RandomBatchSampler, + ReproduceBatchSampler, ReproducibleSampler, RandomSampler, re_instantiate_sampler, @@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver): logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) else: - batch_sampler = RandomBatchSampler( + batch_sampler = ReproduceBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 99ba754e..6c125a73 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -15,7 +15,7 @@ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, 
replace_batch_sampler from fastNLP.core.utils import auto_param_call from fastNLP.core.utils.utils import _get_fun_msg -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler from fastNLP.core.samplers import RandomSampler from fastNLP.core.log import logger @@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver): logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") return replace_sampler(dataloader, sampler) else: - batch_sampler = RandomBatchSampler( + batch_sampler = ReproduceBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 172a3cf0..8c332251 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler class TorchDriver(Driver): @@ -293,7 +293,7 @@ class TorchDriver(Driver): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " "`ReproducibleSampler`.") else: - sampler = RandomBatchSampler( + sampler = ReproduceBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, 
batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last @@ -407,7 +407,7 @@ class TorchDriver(Driver): res.shuffle = True else: res.shuffle = False - # RandomBatchSampler 的情况 + # ReproduceBatchSampler 的情况 elif hasattr(dataloader.batch_sampler, "batch_sampler"): batch_sampler = dataloader.batch_sampler.batch_sampler res.sampler = batch_sampler.sampler diff --git a/fastNLP/core/log/print.py b/fastNLP/core/log/print.py new file mode 100644 index 00000000..30797b89 --- /dev/null +++ b/fastNLP/core/log/print.py @@ -0,0 +1,25 @@ +__all__ = [ + 'print' +] + +from .logger import logger + + +def print(*args, sep=' ', end='\n', file=None, flush=False): + """ + 用来重定向 print 函数至 logger.info 的函数。 + + Example: + from fastNLP import print + + print("This is a test") # 等价于调用了 logger.info("This is a test") + + :param args: 需要打印的内容 + :param sep: 存在多个输入时,使用的间隔。 + :param end: 该参数在当前设置无意义,因为结尾一定会被加入 \n 。 + :param file: 该参数无意义。 + :param flush: 该参数无意义。 + :return: + """ + line = sep.join(args) + logger.info(line) \ No newline at end of file diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index edc1f891..53c29689 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -14,9 +14,10 @@ __all__ = [ "UnrepeatedSortedSampler", "UnrepeatedSequentialSampler", - "RandomBatchSampler", + "ReproduceBatchSampler", "BucketedBatchSampler", "ReproducibleBatchSampler", + "RandomBatchSampler", "re_instantiate_sampler" ] @@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler from .utils import re_instantiate_sampler from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler -from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler +from .reproducible_batch_sampler import ReproduceBatchSampler, 
BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 2bbf409f..88fcb462 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -1,5 +1,6 @@ __all__ = [ 'BucketedBatchSampler', + "ReproduceBatchSampler", "RandomBatchSampler" ] @@ -7,7 +8,6 @@ import math from copy import deepcopy from typing import Dict, Union, List from itertools import chain -import os import numpy as np @@ -54,13 +54,12 @@ class ReproducibleBatchSampler: raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") -class RandomBatchSampler(ReproducibleBatchSampler): - # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; +class ReproduceBatchSampler(ReproducibleBatchSampler): def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): """ 可以使得 batch_sampler 对象状态恢复的 wrapper 。 - :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 + :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproduceBatchSampler 将首先遍历一边该对象,然后将迭代 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 :param batch_size: 每个 batch 的大小是多少。 :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 @@ -143,7 +142,7 @@ class RandomBatchSampler(ReproducibleBatchSampler): self.need_reinitialize = False def set_distributed(self, num_replicas, rank, pad=True): - raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.") + raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") def set_epoch(self, epoch): if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): @@ -158,6 +157,211 @@ class RandomBatchSampler(ReproducibleBatchSampler): (len(self.index_list) - self.num_consumed_samples + 
self.batch_size - 1) // self.batch_size +class RandomBatchSampler(ReproducibleBatchSampler): + def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True, + drop_last: bool = False, seed: int = 0, **kwargs): + """ + 随机分 batch 的 batch_sampler 。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param batch_size: 每个 batch 的大小 + :param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 + :param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。 + :param seed: 设置的随机数种子 + :param kwargs: fastNLP 保留使用 + """ + super().__init__() + + self.dataset = dataset + + self.batch_size = batch_size + self.shuffle = shuffle + self.drop_last = drop_last + self.seed = seed + + self.num_consumed_samples = kwargs.get("num_consumed_samples", 0) # 总共迭代了多少数据了,包括多卡情况下的其它卡上的输出的数量 + + # 多卡的相关的参数 + self.num_replicas = kwargs.get("num_replicas", 1) + self.rank = kwargs.get("rank", 0) + self.epoch = kwargs.get("epoch", -1) + self.pad = kwargs.get("pad", False) # 该参数在单卡上不具有任何意义; + + # 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict() + self.during_iter = kwargs.get("during_iter", False) + + # 以下变量为内部使用恢复状态的变量。 + self.old_batch_size = kwargs.get('old_batch_size', self.batch_size) + + def set_distributed(self, num_replicas, rank, pad=True): + assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ + "during an unfinished iteration." 
+ assert num_replicas > 0 and isinstance(num_replicas, int) + assert isinstance(rank, int) and 0 <= rank < num_replicas + # 注意初始化该函数时,所有的状态都应当默认是一个 epoch 刚开始训练的状态; + self.num_replicas = num_replicas + self.rank = rank + self.pad = pad + + return self + + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + + indices = list(range(len(self.dataset))) + + if self.shuffle: + if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 + _batches = [] + for _i in range(self.old_num_replicas): + _indices = indices[_i:len(indices):self.old_num_replicas] + __batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch) + _batches.append(__batches) + batches = list(chain(*[_ for _ in zip(*_batches)])) + indices = list(chain(*batches)) + indices = indices[self.num_consumed_samples:] + # 取出这个 rank , + indices = indices[self.rank:len(indices):self.num_replicas] + batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch) + batches = list(map(list, batches)) + else: + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + _num_batches = len(indices) // self.batch_size + if _num_batches == 0: + batches = [indices] + else: + batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches))) + if len(indices)%self.batch_size!=0: + batches.append(indices[_num_batches*self.batch_size:]) + + need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas + if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: + if len(batches) > 0: + if len(batches[-1])self.rank: + if len(batches): + batches[-1].pop(-1) + if len(batches[-1])==0: + batches.pop(-1) + + assert sum(map(len, batches)) == self.num_left_samples + + if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: + batches = batches[:-1] + + for batch in batches: + 
self.num_consumed_samples += self.num_replicas * len(batch) + yield list(map(int, batch)) + self.during_iter = False + self.num_consumed_samples = 0 + self.old_batch_size = self.batch_size + self.old_num_replicas = self.num_replicas + if self.epoch < 0: # 防止用户没有修改epoch,导致每个epoch都一样了 + self.epoch -= 1 + + def batchify(self, indices, batch_size, seed): + """ + 将 indices 分为 batches + + :param sorted_indices: List[int] + :param batch_size: int + :param seed: int + :return: List[List[int]] + """ + # 实际的 bucket 大小 + rng = np.random.default_rng(abs(seed)) + rng.shuffle(indices) + num_samples = 0 + batches = [] + while num_samplesint: + """ + 返回当前 sampler 还会返回多少个 batch 的数据 + + :return: + """ + num_sampler_per_rank = self.total_size//self.num_replicas + num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \ + (num_sampler_per_rank+self.batch_size-1)//self.batch_size + return num_batches + + def state_dict(self) -> Dict: + if self.old_batch_size != self.batch_size: + raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" + " consumed. ") + states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, + 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle, + 'batch_size': self.batch_size, + 'num_replicas': self.num_replicas} + + return states + + def load_state_dict(self, states: Dict): + # 如果 self.during_iter 是 True,那么 num_consumed_samples 一定是 0; + assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ + "during an unfinished iteration." + + assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ + f"we cannot use {self.__class__.__name__} to load it." + + length = states['length'] + assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ + "and current dataset." 
+ self.seed = states['seed'] + self.epoch = states['epoch'] + self.num_consumed_samples = states['num_consumed_samples'] + if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 + self.num_consumed_samples = 0 + if self.shuffle != states['shuffle']: + logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " + f"we use shuffle={states['shuffle']}") + self.shuffle = states["shuffle"] + self.old_batch_size = states['batch_size'] + self.old_num_replicas = states['num_replicas'] + + class BucketedBatchSampler(ReproducibleBatchSampler): def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index c8425dc7..de43d7df 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -16,6 +16,8 @@ from fastNLP.core.dataset import DataSet class ReproducibleSampler: """ + 可复现的 Sampler 对象。 + 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 @@ -54,13 +56,12 @@ class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): """ - :param dataset: 实现了 __len__ 方法的数据容器 :param shuffle: 是否在每次 iterate 的时候打乱顺序。 :param seed: 随机数种子。 :param kwargs: 用户不需要使用,fastNLP 内部使用 """ - + super(RandomSampler, self).__init__() self.dataset = dataset self.shuffle = shuffle self.seed = seed diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 910a2df0..9fb538a9 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -21,7 +21,6 @@ __all__ = [ 'nullcontext', 'pretty_table_printer', 'Option', - 'indice_collate_wrapper', 'deprecated', 
'seq_len_to_mask', 'rank_zero_rm', @@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device from .torch_utils import torch_move_data_to_device from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ - indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir + deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir +from ..dataloaders.utils import indice_collate_wrapper diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py index 2856b656..42200cbb 100644 --- a/fastNLP/core/utils/dummy_class.py +++ b/fastNLP/core/utils/dummy_class.py @@ -1,5 +1,5 @@ import functools class DummyClass: - def __call__(self, *args, **kwargs): - return + def __init__(self, *args, **kwargs): + pass diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index e65cd735..e4c0a8a9 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -35,6 +35,7 @@ def paddle_to(data, device: Union[str, int]): else: return data.cuda(get_paddle_device_id(device)) + def get_paddle_gpu_str(device: Union[str, int]): """ 获得 `gpu:x` 类型的设备名 @@ -46,6 +47,7 @@ def get_paddle_gpu_str(device: Union[str, int]): return device.replace("cuda", "gpu") return f"gpu:{device}" + def get_paddle_device_id(device: Union[str, int]): """ 获得 gpu 的设备id @@ -94,18 +96,21 @@ def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) + def is_in_paddle_dist(): """ 判断是否处于分布式的进程下,使用 global_rank 和 selected_gpus 判断 """ return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) + def is_in_fnlp_paddle_dist(): """ 判断是否处于 FastNLP 拉起的分布式进程中 """ return FASTNLP_DISTRIBUTED_CHECK in os.environ + def is_in_paddle_launch_dist(): """ 判断是否处于 launch 
启动的分布式进程中 diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index c3f57bcf..91b3c8f6 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -6,7 +6,7 @@ import warnings from dataclasses import is_dataclass from copy import deepcopy from collections import defaultdict, OrderedDict -from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional +from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence from typing import Tuple, Optional from time import sleep @@ -35,7 +35,6 @@ __all__ = [ 'nullcontext', 'pretty_table_printer', 'Option', - 'indice_collate_wrapper', 'deprecated', 'seq_len_to_mask', 'rank_zero_rm', @@ -513,24 +512,6 @@ class Option(dict): self.update(state) -def indice_collate_wrapper(func): - """ - 其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 - - :param func: 需要修饰的函数 - :return: - """ - - def wrapper(tuple_data): - indice, ins_list = [], [] - for idx, ins in tuple_data: - indice.append(idx) - ins_list.append(ins) - return indice, func(ins_list) - - return wrapper - - _emitted_deprecation_warnings = set() diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index 2796bb69..a14439ce 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -332,13 +332,44 @@ class DataBundle: show_progress_bar=show_progress_bar, progress_desc=progress_desc) return res - def set_pad_val(self, *field_names, val=0) -> None: + def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , 
fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 + 无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: self + """ for _, ds in self.iter_datasets(): - ds.set_pad_val(*field_names, val=val) + ds.collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, backend=backend, + pad_fn=pad_fn) + return self - def set_input(self, *field_names) -> None: + def set_ignore(self, *field_names) -> "DataBundle": + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: self + """ for _, ds in self.iter_datasets(): - ds.set_input(*field_names) + ds.collator.set_ignore(*field_names) + return self def __repr__(self) -> str: _str = '' diff --git a/tests/core/callbacks/test_callback_event.py b/tests/core/callbacks/test_callback_event.py new file mode 100644 index 00000000..765c4432 --- /dev/null +++ b/tests/core/callbacks/test_callback_event.py @@ -0,0 +1,208 @@ +import pytest +from functools import reduce + +from fastNLP.core.callbacks.callback_event import Event, Filter + + + +class TestFilter: + def test_every_filter(self): + # every = 10 + @Filter(every=10) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [w-1 for w in 
range(10, 101, 10)] + + # every = 1 + @Filter(every=1) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == list(range(100)) + + def test_once_filter(self): + # once = 10 + @Filter(once=10) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [9] + + + def test_extract_filter_from_fn(self): + @Filter(every=10) + def _fn(data): + return data + + _filter_num_called = [] + _filter_num_executed = [] + for i in range(100): + cu_res = _fn(i) + _filter = _fn.__fastNLP_filter__ + _filter_num_called.append(_filter.num_called) + _filter_num_executed.append(_filter.num_executed) + assert _filter_num_called == list(range(1, 101)) + assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10] + + def _fn(data): + return data + assert not hasattr(_fn, "__fastNLP_filter__") + + def test_filter_state_dict(self): + # every = 10 + @Filter(every=10) + def _fn(data): + return data + + _res = [] + for i in range(50): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [w - 1 for w in range(10, 51, 10)] + + # 保存状态 + state = _fn.__fastNLP_filter__.state_dict() + # 加载状态 + _fn.__fastNLP_filter__.load_state_dict(state) + + _res = [] + for i in range(50, 100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [w - 1 for w in range(60, 101, 10)] + + +@pytest.mark.torch +def test_filter_fn_torch(): + from torch.optim import SGD + from torch.utils.data import DataLoader + from fastNLP.core.controllers.trainer import Trainer + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification + + model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + optimizer = SGD(model.parameters(), 
lr=0.0001) + dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) + dataloader = DataLoader(dataset=dataset, batch_size=4) + + trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) + def filter_fn(filter, trainer): + if trainer.__heihei_test__ == 10: + return True + return False + + @Filter(filter_fn=filter_fn) + def _fn(trainer, data): + return data + + _res = [] + for i in range(100): + trainer.__heihei_test__ = i + cu_res = _fn(trainer, i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [10] + + +class TestCallbackEvents: + def test_every(self): + + # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试; + event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1; + @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == list(range(100)) + + event_state = Event.on_train_begin(every=10) + @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [w - 1 for w in range(10, 101, 10)] + + def test_once(self): + event_state = Event.on_train_begin(once=10) + + @Filter(once=event_state.once) + def _fn(data): + return data + + _res = [] + for i in range(100): + cu_res = _fn(i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [9] + + +@pytest.mark.torch +def test_callback_events_torch(): + from torch.optim import SGD + from torch.utils.data import DataLoader + from fastNLP.core.controllers.trainer import Trainer + from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 + from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification + + model = 
TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) + optimizer = SGD(model.parameters(), lr=0.0001) + dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) + dataloader = DataLoader(dataset=dataset, batch_size=4) + + trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) + def filter_fn(filter, trainer): + if trainer.__heihei_test__ == 10: + return True + return False + + event_state = Event.on_train_begin(filter_fn=filter_fn) + + @Filter(filter_fn=event_state.filter_fn) + def _fn(trainer, data): + return data + + _res = [] + for i in range(100): + trainer.__heihei_test__ = i + cu_res = _fn(trainer, i) + if cu_res is not None: + _res.append(cu_res) + assert _res == [10] + + + + + + + + + diff --git a/tests/core/callbacks/test_callback_events.py b/tests/core/callbacks/test_callback_events.py deleted file mode 100644 index 8712b469..00000000 --- a/tests/core/callbacks/test_callback_events.py +++ /dev/null @@ -1,157 +0,0 @@ -import pytest -from functools import reduce - -from fastNLP.core.callbacks.callback_events import Events, Filter - - -class TestFilter: - - def test_params_check(self): - # 顺利通过 - _filter1 = Filter(every=10) - _filter2 = Filter(once=10) - _filter3 = Filter(filter_fn=lambda: None) - - # 触发 ValueError - with pytest.raises(ValueError) as e: - _filter4 = Filter() - exec_msg = e.value.args[0] - assert exec_msg == "If you mean your decorated function should be called every time, you do not need this filter." - - # 触发 ValueError - with pytest.raises(ValueError) as e: - _filter5 = Filter(every=10, once=10) - exec_msg = e.value.args[0] - assert exec_msg == "These three values should be only set one." 
- - # 触发 TypeError - with pytest.raises(ValueError) as e: - _filter6 = Filter(every="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument every should be integer and greater than zero" - - # 触发 TypeError - with pytest.raises(ValueError) as e: - _filter7 = Filter(once="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument once should be integer and positive" - - # 触发 TypeError - with pytest.raises(TypeError) as e: - _filter7 = Filter(filter_fn="heihei") - exec_msg = e.value.args[0] - assert exec_msg == "Argument event_filter should be a callable" - - def test_every_filter(self): - # every = 10 - @Filter(every=10) - def _fn(data): - return data - - _res = [] - for i in range(100): - cu_res = _fn(i) - if cu_res is not None: - _res.append(cu_res) - assert _res == [w-1 for w in range(10, 101, 10)] - - # every = 1 - @Filter(every=1) - def _fn(data): - return data - - _res = [] - for i in range(100): - cu_res = _fn(i) - if cu_res is not None: - _res.append(cu_res) - assert _res == list(range(100)) - - def test_once_filter(self): - # once = 10 - @Filter(once=10) - def _fn(data): - return data - - _res = [] - for i in range(100): - cu_res = _fn(i) - if cu_res is not None: - _res.append(cu_res) - assert _res == [9] - - def test_filter_fn(self): - from torch.optim import SGD - from torch.utils.data import DataLoader - from fastNLP.core.controllers.trainer import Trainer - from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 - from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification - - model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10) - optimizer = SGD(model.parameters(), lr=0.0001) - dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10) - dataloader = DataLoader(dataset=dataset, batch_size=4) - - trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer) - def filter_fn(filter, trainer): - 
if trainer.__heihei_test__ == 10: - return True - return False - - @Filter(filter_fn=filter_fn) - def _fn(trainer, data): - return data - - _res = [] - for i in range(100): - trainer.__heihei_test__ = i - cu_res = _fn(trainer, i) - if cu_res is not None: - _res.append(cu_res) - assert _res == [10] - - def test_extract_filter_from_fn(self): - @Filter(every=10) - def _fn(data): - return data - - _filter_num_called = [] - _filter_num_executed = [] - for i in range(100): - cu_res = _fn(i) - _filter = _fn.__fastNLP_filter__ - _filter_num_called.append(_filter.num_called) - _filter_num_executed.append(_filter.num_executed) - assert _filter_num_called == list(range(1, 101)) - assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10] - - def _fn(data): - return data - assert not hasattr(_fn, "__fastNLP_filter__") - - def test_filter_state_dict(self): - # every = 10 - @Filter(every=10) - def _fn(data): - return data - - _res = [] - for i in range(50): - cu_res = _fn(i) - if cu_res is not None: - _res.append(cu_res) - assert _res == [w - 1 for w in range(10, 51, 10)] - - # 保存状态 - state = _fn.__fastNLP_filter__.state_dict() - # 加载状态 - _fn.__fastNLP_filter__.load_state_dict(state) - - _res = [] - for i in range(50, 100): - cu_res = _fn(i) - if cu_res is not None: - _res.append(cu_res) - assert _res == [w - 1 for w in range(60, 101, 10)] - - diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 976b68ba..2de21825 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -2,9 +2,6 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist from pathlib import Path import re import time @@ -20,6 +17,11 @@ from tests.helpers.datasets.torch_data import 
TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.log import logger +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: @@ -216,9 +218,9 @@ def test_model_checkpoint_callback_2( path = Path.cwd().joinpath("test_model_checkpoint") path.mkdir(exist_ok=True, parents=True) - from fastNLP.core.callbacks.callback_events import Events + from fastNLP.core.callbacks.callback_event import Event - @Trainer.on(Events.on_train_epoch_end) + @Trainer.on(Event.on_train_epoch_end()) def raise_exception(trainer): if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4: raise NotImplementedError @@ -550,7 +552,7 @@ def test_trainer_checkpoint_callback_2( if version == 0: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -558,12 +560,13 @@ def test_trainer_checkpoint_callback_2( topk=None, last=False, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] elif version == 1: callbacks = [ - TrainerCheckpointCallback( + CheckpointCallback( monitor="acc", folder=path, every_n_epochs=None, @@ -571,7 +574,8 @@ def test_trainer_checkpoint_callback_2( topk=1, last=True, on_exception=None, - model_save_fn=model_save_fn + model_save_fn=model_save_fn, + save_object="trainer" ) ] diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 2b59ccd5..08c6f8e2 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -12,9 +12,7 @@ import os import pytest from typing import Any from dataclasses import dataclass -from torch.utils.data import DataLoader -from torch.optim import SGD -import torch.distributed as dist + from pathlib import Path import 
re @@ -29,7 +27,11 @@ from torchmetrics import Accuracy from fastNLP.core.metrics import Metric from fastNLP.core.log import logger from fastNLP.core.callbacks import MoreEvaluateCallback - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist @dataclass class ArgMaxDatasetConfig: diff --git a/tests/core/collators/padders/test_get_padder.py b/tests/core/collators/padders/test_get_padder.py index 4aa3d4de..a07a943e 100644 --- a/tests/core/collators/padders/test_get_padder.py +++ b/tests/core/collators/padders/test_get_padder.py @@ -17,12 +17,13 @@ def test_get_element_shape_dtype(): @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) @pytest.mark.torch @pytest.mark.paddle +@pytest.mark.jittor def test_get_padder_run(backend): if not _NEED_IMPORT_TORCH and backend == 'torch': pytest.skip("No torch") if not _NEED_IMPORT_PADDLE and backend == 'paddle': pytest.skip("No paddle") - if not _NEED_IMPORT_PADDLE and backend == 'jittor': + if not _NEED_IMPORT_JITTOR and backend == 'jittor': pytest.skip("No jittor") batch_field = [1, 2, 3] padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') @@ -66,6 +67,13 @@ def test_raw_padder(): pad_batch = padder(batch_field) assert np.shape(pad_batch) == (3, 3, 2) + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, list) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 + def test_numpy_padder(): backend = 'numpy' @@ -140,3 +148,18 @@ def test_torch_padder(): with pytest.raises(InconsistencyError): padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + # 可以是 numpy.ndarray + 
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 + + # 测试 to numpy + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, np.ndarray) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 diff --git a/tests/core/collators/padders/test_paddle_padder.py b/tests/core/collators/padders/test_paddle_padder.py index 80abf30a..bea10de0 100644 --- a/tests/core/collators/padders/test_paddle_padder.py +++ b/tests/core/collators/padders/test_paddle_padder.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from fastNLP.core.collators.padders.paddle_padder import paddleTensorPadder, paddleSequencePadder, paddleNumberPadder +from fastNLP.core.collators.padders.paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder from fastNLP.core.collators.padders.exceptions import DtypeError from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -10,9 +10,9 @@ if _NEED_IMPORT_PADDLE: @pytest.mark.paddle -class TestpaddleNumberPadder: +class TestPaddleNumberPadder: def test_run(self): - padder = paddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = PaddleNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) a = [1, 2, 3] t_a = padder(a) assert isinstance(t_a, paddle.Tensor) @@ -20,9 +20,9 @@ class TestpaddleNumberPadder: @pytest.mark.paddle -class TestpaddleSequencePadder: +class TestPaddleSequencePadder: def test_run(self): - padder = paddleSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=int, 
dtype=int, pad_val=-1) a = [[1, 2, 3], [3]] a = padder(a) shape = a.shape @@ -32,20 +32,20 @@ class TestpaddleSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - padder = paddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.zeros(3, dtype=np.int32).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): - padder = paddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.int32, dtype=None, pad_val=-1) a = padder([[1], [2, 322]]) # assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 - padder = paddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + padder = PaddleSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) @pytest.mark.paddle -class TestpaddleTensorPadder: +class TestPaddleTensorPadder: def test_run(self): - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3,)).dtype, dtype=paddle.zeros((3,)).dtype, pad_val=-1) a = [paddle.zeros((3,)), paddle.zeros((2,))] a = padder(a) shape = a.shape @@ -74,7 +74,7 @@ class TestpaddleTensorPadder: [[0, -1], [-1, -1], [-1, -1]]]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, )).dtype, dtype=paddle.zeros((3, )).dtype, pad_val=-1) a = [paddle.zeros((3, 2)), paddle.zeros((2, 2))] a = padder(a) shape = a.shape @@ -85,7 +85,7 @@ class 
TestpaddleTensorPadder: ]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = paddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=paddle.zeros((3, 2)).dtype, dtype=None, pad_val=-1) a = [np.zeros((3, 2), dtype=np.float32), np.zeros((2, 2), dtype=np.float32)] a = padder(a) shape = a.shape @@ -96,11 +96,11 @@ class TestpaddleTensorPadder: assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] def test_dtype_check(self): - padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) with pytest.raises(DtypeError): - padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) - padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) + padder = PaddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) def test_v1(self): print(paddle.zeros((3, )).dtype) diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py index 9742bc9a..9cb38766 100644 --- a/tests/core/collators/padders/test_raw_padder.py +++ b/tests/core/collators/padders/test_raw_padder.py @@ -23,7 +23,6 @@ class TestRawSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - with pytest.raises(DtypeError): - padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) + padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) \ No newline at end of file diff --git a/tests/core/collators/test_collator.py 
b/tests/core/collators/test_collator.py index 2b56624a..ba1e7e08 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -1,81 +1,293 @@ + +import numpy as np import pytest -from fastNLP.core.collators import AutoCollator -from fastNLP.core.collators.collator import _MultiCollator -from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + +from fastNLP.core.collators.collator import Collator + + +def _assert_equal(d1, d2): + try: + if 'torch' in str(type(d1)): + if 'float64' in str(d2.dtype): + print(d2.dtype) + assert (d1 == d2).all().item() + else: + assert all(d1 == d2) + except TypeError: + assert d1 == d2 + except ValueError: + assert (d1 == d2).all() + + +def findDictDiff(d1, d2, path=""): + for k in d1: + if k in d2: + if isinstance(d1[k], dict): + findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) + else: + _assert_equal(d1[k], d2[k]) + else: + raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) + + +def findListDiff(d1, d2): + assert len(d1)==len(d2) + for _d1, _d2 in zip(d1, d2): + if isinstance(_d1, list): + findListDiff(_d1, _d2) + else: + _assert_equal(_d1, _d2) class TestCollator: - @pytest.mark.parametrize('as_numpy', [True, False]) - def test_auto_collator(self, as_numpy): - """ - 测试auto_collator的auto_pad功能 - - :param as_numpy: - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=as_numpy) - collator.set_input('x', 'y') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - results = [] - for bucket in bucket_data: - res = collator(bucket) - assert res['x'].shape == (40, 5) - assert res['y'].shape == (40,) - results.append(res) - - def test_auto_collator_v1(self): - """ - 
测试auto_collator的set_pad_val和set_pad_val功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - collator.set_input('x') - collator.set_pad_val('x', val=-1) - collator.set_as_numpy(True) - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = collator(bucket) - print(res) - - def test_multicollator(self): - """ - 测试multicollator功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - multi_collator = _MultiCollator(collator) - multi_collator.set_as_numpy(as_numpy=True) - multi_collator.set_pad_val('x', val=-1) - multi_collator.set_input('x') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = multi_collator(bucket) - print(res) + @pytest.mark.torch + def test_run(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 
'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + collator = Collator(backend='raw') + assert raw_pad_batch == collator(dict_batch) + collator = Collator(backend='raw') + raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='numpy') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), + 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), + 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), + 'b': np.array([[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='numpy') + numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), + np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), + np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(numpy_pad_lst, collator(list_batch)) + + if _NEED_IMPORT_TORCH: + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 
2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(list_batch)) + + def test_pad(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': 
[[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 测试 set_pad + collator = Collator(backend='raw') + collator.set_pad('str', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + # 测试设置 pad 值 + collator = Collator(backend='raw') + collator.set_pad('nest_lst_int', pad_val=100) + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 设置 backend 和 type + collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + + # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + # [{'1'}, {'2'}]] + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_4', pad_val=None) + raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], + [1.1, 2.1], 
[[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='raw') + collator.set_pad('_0', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_2', backend='numpy') + collator.set_pad('_4', backend='numpy', pad_val=100) + raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + # _single + collator = Collator() + collator.set_pad('_single') + findListDiff(list_batch, collator(list_batch)) + + def test_nest_ignore(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} + } + ] + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 
'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': {'int':[1, 1]}}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=None) + collator.set_ignore('str', 'int', 'lst_int') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [{'int':1}, {'int':1}]}} + pad_batch = collator(dict_batch) + findDictDiff(raw_pad_batch, pad_batch) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int') + collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) + pad_batch = collator(dict_batch) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [1, 1]}} + findDictDiff(raw_pad_batch, pad_batch) + + + + + + diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py deleted file mode 100644 index 87762c16..00000000 --- a/tests/core/collators/test_new_collator.py +++ /dev/null @@ -1,293 +0,0 @@ - -import numpy as np -import pytest - -from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR - -from fastNLP.core.collators.new_collator import Collator - - -def _assert_equal(d1, d2): - try: - if 'torch' in str(type(d1)): - if 'float64' in 
str(d2.dtype): - print(d2.dtype) - assert (d1 == d2).all().item() - else: - assert all(d1 == d2) - except TypeError: - assert d1 == d2 - except ValueError: - assert (d1 == d2).all() - - -def findDictDiff(d1, d2, path=""): - for k in d1: - if k in d2: - if isinstance(d1[k], dict): - findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) - else: - _assert_equal(d1[k], d2[k]) - else: - raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) - - -def findListDiff(d1, d2): - assert len(d1)==len(d2) - for _d1, _d2 in zip(d1, d2): - if isinstance(_d1, list): - findListDiff(_d1, _d2) - else: - _assert_equal(_d1, _d2) - - -class TestCollator: - - @pytest.mark.torch - def test_run(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - collator = Collator(backend='raw') - assert raw_pad_batch == collator(dict_batch) - collator = Collator(backend='raw') - raw_pad_lst = [['1', '2'], 
[['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='numpy') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), - 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), - 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), - 'b': np.array([[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='numpy') - numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), - np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), - np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(numpy_pad_lst, collator(list_batch)) - - if _NEED_IMPORT_TORCH: - import torch - collator = Collator(backend='torch') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), - 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), - 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - 'float': torch.FloatTensor([1.1, 2.1]), - 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), - 'numpy': torch.FloatTensor([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), - 'b': torch.LongTensor( - [[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='torch') - torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 
2]), torch.LongTensor([[1, 0], [2, 2]]), - torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), - torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(torch_pad_lst, collator(list_batch)) - - def test_pad(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 测试 set_pad - collator = Collator(backend='raw') - collator.set_pad('str', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - # 测试设置 pad 值 - collator = Collator(backend='raw') - 
collator.set_pad('nest_lst_int', pad_val=100) - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 设置 backend 和 type - collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - - # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - # [{'1'}, {'2'}]] - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_4', pad_val=None) - raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='raw') - collator.set_pad('_0', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 
2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_2', backend='numpy') - collator.set_pad('_4', backend='numpy', pad_val=100) - raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - # _single - collator = Collator() - collator.set_pad('_single') - findListDiff(list_batch, collator(list_batch)) - - def test_nest_ignore(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} - } - ] - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': {'int':[1, 1]}}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=None) - collator.set_ignore('str', 'int', 'lst_int') - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': 
[1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [{'int':1}, {'int':1}]}} - pad_batch = collator(dict_batch) - findDictDiff(raw_pad_batch, pad_batch) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int') - collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) - pad_batch = collator(dict_batch) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [1, 1]}} - findDictDiff(raw_pad_batch, pad_batch) - - - - - - diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index bcd89614..73eb0d6d 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -1,17 +1,20 @@ import pytest from typing import Any from dataclasses import dataclass -from torch.optim import SGD -from torch.utils.data import DataLoader -from torchmetrics import Accuracy -import torch.distributed as dist + from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.callbacks.callback_events import Events +from fastNLP.core.callbacks.callback_event import Event from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordTrainerEventTriggerCallback from tests.helpers.utils 
import magic_argv_env_context, Capturing +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + from torchmetrics import Accuracy + import torch.distributed as dist @dataclass @@ -62,12 +65,11 @@ def model_and_optimizers(): return trainer_params - +@pytest.mark.torch @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) -@pytest.mark.torch @magic_argv_env_context -def test_trainer_event_trigger( +def test_trainer_event_trigger_1( model_and_optimizers: TrainerParameters, driver, device, @@ -97,8 +99,215 @@ def test_trainer_event_trigger( if dist.is_initialized(): dist.destroy_process_group() - for name, member in Events.__members__.items(): - assert member.value in output[0] + Event_attrs = Event.__dict__ + for k, v in Event_attrs.items(): + if isinstance(v, staticmethod): + assert k in output[0] + +@pytest.mark.torch +@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) +@magic_argv_env_context +def test_trainer_event_trigger_2( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + + @Trainer.on(Event.on_after_trainer_initialized()) + def on_after_trainer_initialized(trainer, driver): + print("on_after_trainer_initialized") + + @Trainer.on(Event.on_sanity_check_begin()) + def on_sanity_check_begin(trainer): + print("on_sanity_check_begin") + + @Trainer.on(Event.on_sanity_check_end()) + def on_sanity_check_end(trainer, sanity_check_res): + print("on_sanity_check_end") + + @Trainer.on(Event.on_train_begin()) + def on_train_begin(trainer): + print("on_train_begin") + + @Trainer.on(Event.on_train_end()) + def on_train_end(trainer): + print("on_train_end") + + @Trainer.on(Event.on_train_epoch_begin()) + def on_train_epoch_begin(trainer): + if trainer.cur_epoch_idx >= 1: + # 触发 
on_exception; + raise Exception + print("on_train_epoch_begin") + + @Trainer.on(Event.on_train_epoch_end()) + def on_train_epoch_end(trainer): + print("on_train_epoch_end") + + @Trainer.on(Event.on_fetch_data_begin()) + def on_fetch_data_begin(trainer): + print("on_fetch_data_begin") + + @Trainer.on(Event.on_fetch_data_end()) + def on_fetch_data_end(trainer): + print("on_fetch_data_end") + + @Trainer.on(Event.on_train_batch_begin()) + def on_train_batch_begin(trainer, batch, indices=None): + print("on_train_batch_begin") + + @Trainer.on(Event.on_train_batch_end()) + def on_train_batch_end(trainer): + print("on_train_batch_end") + + @Trainer.on(Event.on_exception()) + def on_exception(trainer, exception): + print("on_exception") + + @Trainer.on(Event.on_before_backward()) + def on_before_backward(trainer, outputs): + print("on_before_backward") + + @Trainer.on(Event.on_after_backward()) + def on_after_backward(trainer): + print("on_after_backward") + + @Trainer.on(Event.on_before_optimizers_step()) + def on_before_optimizers_step(trainer, optimizers): + print("on_before_optimizers_step") + + @Trainer.on(Event.on_after_optimizers_step()) + def on_after_optimizers_step(trainer, optimizers): + print("on_after_optimizers_step") + + @Trainer.on(Event.on_before_zero_grad()) + def on_before_zero_grad(trainer, optimizers): + print("on_before_zero_grad") + + @Trainer.on(Event.on_after_zero_grad()) + def on_after_zero_grad(trainer, optimizers): + print("on_after_zero_grad") + + @Trainer.on(Event.on_evaluate_begin()) + def on_evaluate_begin(trainer): + print("on_evaluate_begin") + + @Trainer.on(Event.on_evaluate_end()) + def on_evaluate_end(trainer, results): + print("on_evaluate_end") + + with pytest.raises(Exception): + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + 
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + Event_attrs = Event.__dict__ + for k, v in Event_attrs.items(): + if isinstance(v, staticmethod): + assert k in output[0] + + +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 6)]) +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_event_trigger_3( + model_and_optimizers: TrainerParameters, + driver, + device, + n_epochs=2, +): + import re + + once_message_1 = "This message should be typed 1 times." + once_message_2 = "test_filter_fn" + once_message_3 = "once message 3" + twice_message = "twice message hei hei" + + @Trainer.on(Event.on_train_epoch_begin(every=2)) + def train_epoch_begin_1(trainer): + print(once_message_1) + + @Trainer.on(Event.on_train_epoch_begin()) + def train_epoch_begin_2(trainer): + print(twice_message) + + @Trainer.on(Event.on_train_epoch_begin(once=2)) + def train_epoch_begin_3(trainer): + print(once_message_3) + + def filter_fn(filter, trainer): + if trainer.cur_epoch_idx == 1: + return True + else: + return False + + @Trainer.on(Event.on_train_epoch_end(filter_fn=filter_fn)) + def test_filter_fn(trainer): + print(once_message_2) + + with Capturing() as output: + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=n_epochs, + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + + once_pattern_1 = 
re.compile(once_message_1) + once_pattern_2 = re.compile(once_message_2) + once_pattern_3 = re.compile(once_message_3) + twice_pattern = re.compile(twice_message) + + once_res_1 = once_pattern_1.findall(output[0]) + assert len(once_res_1) == 1 + once_res_2 = once_pattern_2.findall(output[0]) + assert len(once_res_2) == 1 + once_res_3 = once_pattern_3.findall(output[0]) + assert len(once_res_3) == 1 + twice_res = twice_pattern.findall(output[0]) + assert len(twice_res) == 2 + + + + + + + + + + + + diff --git a/tests/core/controllers/test_trainer_other_things.py b/tests/core/controllers/test_trainer_other_things.py index 9cdec2dd..3d9a5037 100644 --- a/tests/core/controllers/test_trainer_other_things.py +++ b/tests/core/controllers/test_trainer_other_things.py @@ -1,22 +1,22 @@ import pytest from fastNLP.core.controllers.trainer import Trainer -from fastNLP.core.callbacks import Events +from fastNLP.core.callbacks import Event from tests.helpers.utils import magic_argv_env_context @magic_argv_env_context def test_trainer_torch_without_evaluator(): - @Trainer.on(Events.on_train_epoch_begin(every=10)) + @Trainer.on(Event.on_train_epoch_begin(every=10), marker="test_trainer_other_things") def fn1(trainer): pass - @Trainer.on(Events.on_train_batch_begin(every=10)) + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn2(trainer, batch, indices): pass - with pytest.raises(AssertionError): - @Trainer.on(Events.on_train_batch_begin(every=10)) + with pytest.raises(BaseException): + @Trainer.on(Event.on_train_batch_begin(every=10), marker="test_trainer_other_things") def fn3(trainer, batch): pass diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 891626b5..f44bd735 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -2,9 +2,7 @@ 注意这一文件中的测试函数都应当是在 
`test_trainer_w_evaluator_torch.py` 中已经测试过的测试函数的基础上加上 metrics 和 evaluator 修改而成; """ import pytest -from torch.optim import SGD -from torch.utils.data import DataLoader -import torch.distributed as dist + from dataclasses import dataclass from typing import Any from torchmetrics import Accuracy @@ -14,7 +12,11 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.optim import SGD + from torch.utils.data import DataLoader + import torch.distributed as dist @dataclass class NormalClassificationTrainTorchConfig: diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 825bd425..102ab310 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -2,9 +2,7 @@ import os.path import subprocess import sys import pytest -import torch.distributed as dist -from torch.optim import SGD -from torch.utils.data import DataLoader + from dataclasses import dataclass from typing import Any from pathlib import Path @@ -16,6 +14,11 @@ from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch.distributed as dist + from torch.optim import SGD + from torch.utils.data import DataLoader @dataclass @@ -257,9 +260,9 @@ def test_trainer_on_exception( cur_rank, n_epochs=2, ): 
- from fastNLP.core.callbacks.callback_events import Events + from fastNLP.core.callbacks.callback_event import Event - @Trainer.on(Events.on_train_epoch_end) + @Trainer.on(Event.on_train_epoch_end()) def raise_exception(trainer): if trainer.driver.get_local_rank() == cur_rank: raise NotImplementedError @@ -286,6 +289,7 @@ def test_trainer_on_exception( dist.destroy_process_group() +@pytest.mark.torch @pytest.mark.parametrize("version", [0, 1, 2, 3]) @magic_argv_env_context def test_torch_distributed_launch_1(version): diff --git a/tests/core/controllers/utils/test_utils.py b/tests/core/controllers/utils/test_utils.py index 0cf7a252..39c1987a 100644 --- a/tests/core/controllers/utils/test_utils.py +++ b/tests/core/controllers/utils/test_utils.py @@ -1,7 +1,7 @@ from functools import reduce from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: 该类修改过,记得将 test 也修改; -from tests.helpers.datasets.normal_data import NormalIterator +from tests.helpers.datasets.normal_data import NormalSampler class Test_WrapDataLoader: @@ -9,9 +9,9 @@ class Test_WrapDataLoader: def test_normal_generator(self): all_sanity_batches = [4, 20, 100] for sanity_batches in all_sanity_batches: - data = NormalIterator(num_of_data=1000) + data = NormalSampler(num_of_data=1000) wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) - dataloader = iter(wrapper(dataloader=data)) + dataloader = iter(wrapper) mark = 0 while True: try: @@ -32,8 +32,7 @@ class Test_WrapDataLoader: dataset = TorchNormalDataset(num_of_data=1000) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) - dataloader = wrapper(dataloader) - dataloader = iter(dataloader) + dataloader = iter(wrapper) all_supposed_running_data_num = 0 while True: try: @@ -55,6 +54,5 @@ class Test_WrapDataLoader: dataset = TorchNormalDataset(num_of_data=1000) dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) wrapper = 
_TruncatedDataLoader(dataloader, num_batches=sanity_batches) - dataloader = wrapper(dataloader) - length.append(len(dataloader)) + length.append(len(wrapper)) assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) \ No newline at end of file diff --git a/tests/core/drivers/jittor_driver/test_single_device.py b/tests/core/drivers/jittor_driver/test_single_device.py index 8bbceed9..2e220974 100644 --- a/tests/core/drivers/jittor_driver/test_single_device.py +++ b/tests/core/drivers/jittor_driver/test_single_device.py @@ -15,7 +15,7 @@ else: -class Model (Module): +class Model(Module): def __init__ (self): super (Model, self).__init__() self.conv1 = nn.Conv (3, 32, 3, 1) # no padding @@ -45,6 +45,7 @@ class Model (Module): return x @pytest.mark.jittor +@pytest.mark.skip("Skip jittor tests now.") class TestSingleDevice: def test_on_gpu_without_fp16(self): diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index a00a41f5..b8ccd802 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -278,7 +278,7 @@ class TestPaddleDriverFunctions: dataset = PaddleNormalDataset() dataloader = DataLoader( dataset, - batch_sampler=RandomBatchSampler( + batch_sampler=ReproduceBatchSampler( BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), batch_size, drop_last, @@ -287,7 +287,7 @@ class 
TestPaddleDriverFunctions: res = PaddleSingleDriver.get_dataloader_args(dataloader) assert isinstance(res.dataset, PaddleNormalDataset) - assert isinstance(res.batch_sampler, RandomBatchSampler) + assert isinstance(res.batch_sampler, ReproduceBatchSampler) if shuffle: assert isinstance(res.sampler, paddle.io.RandomSampler) else: @@ -387,7 +387,7 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), - 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) @@ -400,7 +400,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) else: # 此时会替换 batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -414,11 +414,11 @@ class TestSetDistReproDataloader: 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) - dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) + dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert 
isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler is dist self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @@ -450,7 +450,7 @@ class TestSetDistReproDataloader: """ dataloader = DataLoader( dataset=self.dataset, - batch_sampler=RandomBatchSampler( + batch_sampler=ReproduceBatchSampler( BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), batch_size=4, drop_last=False, @@ -459,7 +459,7 @@ class TestSetDistReproDataloader: replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -500,20 +500,20 @@ class TestSetDistReproDataloader: if idx >= num_consumed_batches: break already_seen_idx.update(batch) - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新改造 dataloader new_loader = DataLoader( dataset=replaced_loader.dataset, - batch_sampler=RandomBatchSampler( + batch_sampler=ReproduceBatchSampler( BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), 
batch_size=batch_size, drop_last=False, @@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): dataset = PaddleRandomMaxDataset(40, 10) dataloader = DataLoader( dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) + batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False) ) driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") @@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 更改 batch_size dataloader = DataLoader( dataset=dataset, - batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) + batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) ) load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") @@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 2. 
检查 batch_sampler 是否被正确地加载和替换 assert not (replaced_loader is dataloader) assert replaced_loader.batch_sampler is dataloader.batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 diff --git a/tests/core/drivers/paddle_driver/test_utils.py b/tests/core/drivers/paddle_driver/test_utils.py index 4b683c1e..3b0fb9e0 100644 --- a/tests/core/drivers/paddle_driver/test_utils.py +++ b/tests/core/drivers/paddle_driver/test_utils.py @@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import ( replace_batch_sampler, replace_sampler, ) -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle @@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, def test_replace_batch_sampler(): dataset = PaddleNormalDataset(10) dataloader = DataLoader(dataset, batch_size=32) - batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) replaced_loader = replace_batch_sampler(dataloader, batch_sampler) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.dataset, PaddleNormalDataset) assert len(replaced_loader.dataset) == len(dataset) assert replaced_loader.batch_sampler.batch_size == 16 diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index 
48299bf4..d6f0ee77 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -13,12 +13,13 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context from fastNLP.core import rank_zero_rm +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.utils.data import DataLoader, BatchSampler -import torch -import torch.distributed as dist -from torch.utils.data import DataLoader, BatchSampler - -def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="all"): torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) device = [torch.device(i) for i in device] @@ -72,108 +73,100 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed= # ############################################################################ +@pytest.mark.torch +@magic_argv_env_context +def test_multi_drivers(): + """ + 测试使用了多个 TorchDDPDriver 的情况。 + """ + generate_driver(10, 10) + generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + generate_driver(20, 3, device=[0,1,2]) + assert False + dist.barrier() + + if dist.is_initialized(): + dist.destroy_process_group() + @pytest.mark.torch class TestDDPDriverFunction: """ 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 """ - @classmethod - def setup_class(cls): - cls.driver = generate_driver(10, 10) - @magic_argv_env_context - def test_multi_drivers(self): + def test_simple_functions(self): """ - 测试使用了多个 TorchDDPDriver 的情况。 + 简单测试多个函数 """ - - driver2 = 
generate_driver(20, 10) - - with pytest.raises(RuntimeError): - # 设备设置不同,应该报错 - driver3 = generate_driver(20, 3, device=[0,1,2]) - assert False - dist.barrier() + driver = generate_driver(10, 10) - @magic_argv_env_context - def test_move_data_to_device(self): """ - 这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中 - 就不重复测试了 + 测试 move_data_to_device 函数。这个函数仅调用了 torch_move_data_to_device ,测试例在 + tests/core/utils/test_torch_utils.py中,就不重复测试了 """ - self.driver.move_data_to_device(torch.rand((32, 64))) - + driver.move_data_to_device(torch.rand((32, 64))) dist.barrier() - @magic_argv_env_context - def test_is_distributed(self): """ 测试 is_distributed 函数 """ - assert self.driver.is_distributed() == True + assert driver.is_distributed() == True dist.barrier() - @magic_argv_env_context - def test_get_no_sync_context(self): """ 测试 get_no_sync_context 函数 """ - res = self.driver.get_model_no_sync_context() + res = driver.get_model_no_sync_context() dist.barrier() - @magic_argv_env_context - def test_is_global_zero(self): """ 测试 is_global_zero 函数 """ - self.driver.is_global_zero() + driver.is_global_zero() dist.barrier() - @magic_argv_env_context - def test_unwrap_model(self): """ 测试 unwrap_model 函数 """ - self.driver.unwrap_model() + driver.unwrap_model() dist.barrier() - @magic_argv_env_context - def test_get_local_rank(self): """ 测试 get_local_rank 函数 """ - self.driver.get_local_rank() + driver.get_local_rank() dist.barrier() - @magic_argv_env_context - def test_all_gather(self): """ 测试 all_gather 函数 详细的测试在 test_dist_utils.py 中完成 """ obj = { - "rank": self.driver.global_rank + "rank": driver.global_rank } - obj_list = self.driver.all_gather(obj, group=None) + obj_list = driver.all_gather(obj, group=None) for i, res in enumerate(obj_list): assert res["rank"] == i - @magic_argv_env_context - @pytest.mark.parametrize("src_rank", ([0, 1])) - def test_broadcast_object(self, src_rank): """ 测试 broadcast_object 函数 详细的函数在 test_dist_utils.py 中完成 """ - if 
self.driver.global_rank == src_rank: + if driver.global_rank == 0: obj = { - "rank": self.driver.global_rank + "rank": driver.global_rank } else: obj = None - res = self.driver.broadcast_object(obj, src=src_rank) - assert res["rank"] == src_rank + res = driver.broadcast_object(obj, src=0) + assert res["rank"] == 0 + + if dist.is_initialized(): + dist.destroy_process_group() ############################################################################ # @@ -187,7 +180,6 @@ class TestSetDistReproDataloader: @classmethod def setup_class(cls): cls.device = [0, 1] - cls.driver = generate_driver(10, 10, device=cls.device) def setup_method(self): self.dataset = TorchNormalDataset(40) @@ -204,17 +196,20 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler is batch_sampler self.check_distributed_sampler(replaced_loader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -223,9 +218,10 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler """ + driver = 
generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) sampler = RandomSampler(self.dataset, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, sampler, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -234,9 +230,11 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler is sampler assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` @@ -251,15 +249,17 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) with pytest.raises(RuntimeError): # 应当抛出 RuntimeError - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, True) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context - # @pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False])) def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): """ @@ -268,21 +268,24 @@ class TestSetDistReproDataloader: 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 
batch_sampler 和原 dataloader 相同 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank, + num_replicas=driver.world_size, + rank=driver.global_rank, pad=True ) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) assert replaced_loader.batch_sampler.batch_size == 4 self.check_distributed_sampler(dataloader.batch_sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -292,12 +295,13 @@ class TestSetDistReproDataloader: 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank + num_replicas=driver.world_size, + rank=driver.global_rank ) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -307,9 +311,11 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.batch_size == 4 assert 
replaced_loader.batch_sampler.drop_last == False self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) - self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + self.check_set_dist_repro_dataloader(driver, dataloader, replaced_loader, shuffle) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -318,11 +324,14 @@ class TestSetDistReproDataloader: 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 此时直接返回原来的 dataloader,不做任何处理。 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, None, False) assert replaced_loader is dataloader dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 @@ -337,12 +346,13 @@ class TestSetDistReproDataloader: 的表现 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader( dataset=self.dataset, batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) ) dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) @@ -351,6 +361,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler) dist.barrier() + if 
dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -361,8 +373,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) @@ -372,6 +385,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -381,8 +396,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "dist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -392,6 +408,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() """ 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 @@ -407,8 +425,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 
dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -418,6 +437,8 @@ class TestSetDistReproDataloader: assert replaced_loader.batch_sampler.sampler.shuffle == shuffle self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -427,8 +448,9 @@ class TestSetDistReproDataloader: 的表现 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler """ + driver = generate_driver(10, 10, device=self.device) dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) - replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -439,6 +461,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() @magic_argv_env_context @pytest.mark.parametrize("shuffle", ([True, False])) @@ -448,8 +472,9 @@ class TestSetDistReproDataloader: 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 的属性 """ + driver = generate_driver(10, 10, device=self.device) dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) - 
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + replaced_loader = driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) assert not (replaced_loader is dataloader) assert isinstance(replaced_loader.batch_sampler, BatchSampler) @@ -459,6 +484,8 @@ class TestSetDistReproDataloader: assert replaced_loader.drop_last == dataloader.drop_last self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) dist.barrier() + if dist.is_initialized(): + dist.destroy_process_group() def check_distributed_sampler(self, sampler): """ @@ -469,7 +496,7 @@ class TestSetDistReproDataloader: if not isinstance(sampler, UnrepeatedSampler): assert sampler.pad == True - def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + def check_set_dist_repro_dataloader(self, driver, dataloader, replaced_loader, shuffle): """ 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 """ @@ -501,8 +528,8 @@ class TestSetDistReproDataloader: drop_last=False, ) new_loader.batch_sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank, + num_replicas=driver.world_size, + rank=driver.global_rank, pad=True ) new_loader.batch_sampler.load_state_dict(sampler_states) @@ -512,8 +539,8 @@ class TestSetDistReproDataloader: # 重新构造 dataloader new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) new_loader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver.world_size, - rank=self.driver.global_rank + num_replicas=driver.world_size, + rank=driver.global_rank ) new_loader.batch_sampler.sampler.load_state_dict(sampler_states) for idx, batch in enumerate(new_loader): @@ -534,11 +561,6 @@ class TestSaveLoad: 测试多卡情况下 save 和 load 相关函数的表现 """ - @classmethod - def setup_class(cls): - # 不在这里 setup 的话会报错 - cls.driver = generate_driver(10, 10) - def setup_method(self): self.dataset = TorchArgMaxDataset(10, 20) @@ -552,26 +574,26 @@ class 
TestSaveLoad: path = "model" dataloader = DataLoader(self.dataset, batch_size=2) - self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + driver1, driver2 = generate_driver(10, 10), generate_driver(10, 10) - self.driver1.save_model(path, only_state_dict) + driver1.save_model(path, only_state_dict) # 同步 dist.barrier() - self.driver2.load_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) for idx, batch in enumerate(dataloader): - batch = self.driver1.move_data_to_device(batch) - res1 = self.driver1.model( + batch = driver1.move_data_to_device(batch) + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -580,6 +602,9 @@ class TestSaveLoad: finally: rank_zero_rm(path) + if dist.is_initialized(): + dist.destroy_process_group() + @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) @@ -593,7 +618,7 @@ class TestSaveLoad: path = "model.ckp" num_replicas = len(device) - self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + driver1, driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ generate_driver(10, 10, device=device, fp16=False) dataloader = dataloader_with_bucketedbatchsampler( self.dataset, @@ -603,8 +628,8 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver1.world_size, - rank=self.driver1.global_rank, + num_replicas=driver1.world_size, + rank=driver1.global_rank, pad=True ) num_consumed_batches = 2 @@ 
-623,7 +648,7 @@ class TestSaveLoad: # 保存状态 sampler_states = dataloader.batch_sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) # 加载 # 更改 batch_size dataloader = dataloader_with_bucketedbatchsampler( @@ -634,11 +659,11 @@ class TestSaveLoad: drop_last=False ) dataloader.batch_sampler.set_distributed( - num_replicas=self.driver2.world_size, - rank=self.driver2.global_rank, + num_replicas=driver2.world_size, + rank=driver2.global_rank, pad=True ) - load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 # TODO optimizer 的 state_dict 总是为空 @@ -652,7 +677,7 @@ class TestSaveLoad: # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 
检查 batch_idx @@ -664,16 +689,16 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = self.driver1.model( + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -686,6 +711,9 @@ class TestSaveLoad: finally: rank_zero_rm(path) + if dist.is_initialized(): + dist.destroy_process_group() + @magic_argv_env_context @pytest.mark.parametrize("only_state_dict", ([True, False])) @pytest.mark.parametrize("fp16", ([True, False])) @@ -700,13 +728,13 @@ class TestSaveLoad: num_replicas = len(device) - self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) - self.driver2 = generate_driver(10, 10, device=device, fp16=False) + driver1 = generate_driver(10, 10, device=device, fp16=fp16) + driver2 = generate_driver(10, 10, device=device, fp16=False) dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver1.world_size, - rank=self.driver1.global_rank, + num_replicas=driver1.world_size, + rank=driver1.global_rank, pad=True ) num_consumed_batches = 2 @@ -726,18 +754,18 @@ class TestSaveLoad: sampler_states = dataloader.batch_sampler.sampler.state_dict() save_states = {"num_consumed_batches": num_consumed_batches} if only_state_dict: - self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) else: - self.driver1.save(Path(path), save_states, dataloader, 
only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) # 加载 # 更改 batch_size dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) dataloader.batch_sampler.sampler.set_distributed( - num_replicas=self.driver2.world_size, - rank=self.driver2.global_rank, + num_replicas=driver2.world_size, + rank=driver2.global_rank, pad=True ) - load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) replaced_loader = load_states.pop("dataloader") # 1. 检查 optimizer 的状态 @@ -753,7 +781,7 @@ class TestSaveLoad: assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] # 3. 检查 fp16 是否被加载 if fp16: - assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) # 4. 检查 model 的参数是否正确 # 5. 
检查 batch_idx @@ -765,16 +793,16 @@ class TestSaveLoad: left_x_batches.update(batch["x"]) left_y_batches.update(batch["y"]) - res1 = self.driver1.model( + res1 = driver1.model( batch, - fastnlp_fn=self.driver1.model.module.model.evaluate_step, + fastnlp_fn=driver1.model.module.model.evaluate_step, # Driver.model -> DataParallel.module -> _FleetWrappingModel.model fastnlp_signature_fn=None, wo_auto_param_call=False, ) - res2 = self.driver2.model( + res2 = driver2.model( batch, - fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_fn=driver2.model.module.model.evaluate_step, fastnlp_signature_fn=None, wo_auto_param_call=False, ) @@ -786,4 +814,7 @@ class TestSaveLoad: assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas finally: - rank_zero_rm(path) \ No newline at end of file + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py index f62ccd0c..8ec70de1 100644 --- a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -2,12 +2,14 @@ import pytest from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver -from fastNLP.envs import get_gpu_count from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.utils import magic_argv_env_context - -import torch - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import device as torchdevice +else: + from fastNLP.core.utils.dummy_class import DummyClass as torchdevice @pytest.mark.torch def test_incorrect_driver(): @@ -20,7 +22,7 @@ def test_incorrect_driver(): @pytest.mark.torch @pytest.mark.parametrize( "device", - ["cpu", "cuda:0", 0, 
torch.device("cuda:0")] + ["cpu", "cuda:0", 0, torchdevice("cuda:0")] ) @pytest.mark.parametrize( "driver", @@ -83,7 +85,6 @@ def test_get_ddp(driver, device): ("driver", "device"), [("torch_ddp", "cpu")] ) -@magic_argv_env_context def test_get_ddp_cpu(driver, device): """ 测试试图在 cpu 上初始化分布式训练的情况 @@ -96,13 +97,12 @@ def test_get_ddp_cpu(driver, device): @pytest.mark.torch @pytest.mark.parametrize( "device", - [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] + [-2, [0, 20, 3], [-2], 20] ) @pytest.mark.parametrize( "driver", ["torch", "torch_ddp"] ) -@magic_argv_env_context def test_device_out_of_range(driver, device): """ 测试传入的device超过范围的情况 diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 8c761a95..ef60e2b6 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -2,7 +2,7 @@ import pytest from pathlib import Path from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset @@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE: def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): """ - 建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader + 建立一个 batch_sampler 为 ReproduceBatchSampler 的 dataloader """ if shuffle: sampler = torch.utils.data.RandomSampler(dataset) @@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): sampler = torch.utils.data.SequentialSampler(dataset) dataloader = DataLoader( dataset=dataset, - batch_sampler=RandomBatchSampler( + 
batch_sampler=ReproduceBatchSampler( BatchSampler( sampler, batch_size=batch_size, drop_last=drop_last ), @@ -306,7 +306,7 @@ class TestTorchDriverFunctions: res = TorchSingleDriver.get_dataloader_args(dataloader) assert isinstance(res.dataset, TorchNormalDataset) - assert isinstance(res.batch_sampler, RandomBatchSampler) + assert isinstance(res.batch_sampler, ReproduceBatchSampler) if shuffle: assert isinstance(res.sampler, torch.utils.data.RandomSampler) else: @@ -401,7 +401,7 @@ class TestSetDistReproDataloader: """ 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), - 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) @@ -414,7 +414,7 @@ class TestSetDistReproDataloader: assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) else: # 此时会替换 batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -428,11 +428,11 @@ class TestSetDistReproDataloader: 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler """ dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) - dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) + dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) 
assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler is dist self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) @@ -466,7 +466,7 @@ class TestSetDistReproDataloader: replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.drop_last == dataloader.drop_last @@ -502,14 +502,14 @@ class TestSetDistReproDataloader: if idx >= num_consumed_batches: break already_seen_idx.update(batch) - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): sampler_states = replaced_loader.batch_sampler.state_dict() else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() # 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range left_idxes = set() - if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler): batch_size = replaced_loader.batch_sampler.batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新改造 dataloader @@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): # 2. 
检查 batch_sampler 是否被正确地加载和替换 assert not (replaced_loader is dataloader) assert replaced_loader.batch_sampler is dataloader.batch_sampler - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index 161bbfe8..56de18fe 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -30,7 +30,7 @@ class SequenceDataSet: def check_replace_sampler(driver): - # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler + # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproduceBatchSampler # reproducible 是 True 和 False # 需要 check 返回的 sampler 和 dataloader 都不同了 diff --git a/tests/core/drivers/torch_driver/test_utils.py b/tests/core/drivers/torch_driver/test_utils.py index 97037b71..8d5d3267 100644 --- a/tests/core/drivers/torch_driver/test_utils.py +++ b/tests/core/drivers/torch_driver/test_utils.py @@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import ( replace_batch_sampler, replace_sampler, ) -from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler from torch.utils.data import DataLoader, BatchSampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset def test_replace_batch_sampler(): dataset = TorchNormalDataset(10) dataloader = DataLoader(dataset, batch_size=32) - batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, 
drop_last=False) + batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) replaced_loader = replace_batch_sampler(dataloader, batch_sampler) assert not (replaced_loader is dataloader) - assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler) assert isinstance(replaced_loader.dataset, TorchNormalDataset) assert len(replaced_loader.dataset) == len(dataset) assert replaced_loader.batch_sampler.batch_size == 16 diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py index b89d15db..cadf4e0e 100644 --- a/tests/core/metrics/test_accuracy_torch.py +++ b/tests/core/metrics/test_accuracy_torch.py @@ -7,15 +7,20 @@ import copy import socket import pytest import numpy as np -import torch -import torch.distributed -from torch.multiprocessing import Pool, set_start_method + from sklearn.metrics import accuracy_score as sklearn_accuracy from fastNLP.core.dataset import DataSet from fastNLP.core.metrics.accuracy import Accuracy from fastNLP.core.metrics.metric import Metric from .utils import find_free_network_port, setup_ddp, _assert_allclose +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -26,7 +31,7 @@ pool = None def _test(local_rank: int, world_size: int, - device: torch.device, + device: "torch.device", dataset: DataSet, metric_class: Type[Metric], metric_kwargs: Dict[str, Any], diff --git a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py index bc006cb1..75203a3e 100644 --- a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py +++ 
b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py @@ -2,18 +2,23 @@ from functools import partial import copy import pytest -import torch + import numpy as np -from torch.multiprocessing import Pool, set_start_method from fastNLP.core.metrics import ClassifyFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) -def _test(local_rank: int, world_size: int, device: torch.device, +def _test(local_rank: int, world_size: int, device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, metric_result): metric = metric_class(**metric_kwargs) # dataset 也类似(每个进程有自己的一个) diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py index 72db05fc..0ebb9bdd 100644 --- a/tests/core/metrics/test_span_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -5,16 +5,21 @@ import os, sys import copy from functools import partial -import torch -import torch.distributed import numpy as np import socket -from torch.multiprocessing import Pool, set_start_method + # from multiprocessing import Pool, set_start_method from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH from .utils import find_free_network_port, setup_ddp +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed + from torch.multiprocessing import Pool, set_start_method +else: + from fastNLP.core.utils.dummy_class import DummyClass as set_start_method set_start_method("spawn", force=True) @@ -44,7 +49,7 @@ pool = None def _test(local_rank: int, world_size: int, - 
device: torch.device, + device: "torch.device", dataset: DataSet, metric_class, metric_kwargs, diff --git a/tests/core/metrics/utils.py b/tests/core/metrics/utils.py index 10157438..4126dc97 100644 --- a/tests/core/metrics/utils.py +++ b/tests/core/metrics/utils.py @@ -2,9 +2,11 @@ import os, sys import socket from typing import Union -import torch -from torch import distributed import numpy as np +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch import distributed def setup_ddp(rank: int, world_size: int, master_port: int) -> None: diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index 6cf4b7d4..c4dd8c50 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -1,161 +1,131 @@ -from array import array - import numpy as np import pytest from itertools import chain from copy import deepcopy +from array import array + +from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler +from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler + + +class TestReproducibleBatchSampler: + def test_1(self): + sampler = NormalSampler(num_of_data=100) # 这里是否是 batchsampler 不影响; + + reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False) + + forward_steps = 3 + iterator = iter(reproduce_batch_sampler) + i = 0 + while i < forward_steps: + next(iterator) + i += 1 + + # 保存状态; + state = reproduce_batch_sampler.state_dict() + + assert state == {"index_list": array("I", list(range(100))), + "num_consumed_samples": forward_steps * 4, + "sampler_type": "ReproduceBatchSampler"} + + # 重新生成一个 batchsampler 然后加载状态; + sampler = NormalSampler(num_of_data=100) # 这里是否是 batchsampler 不影响; + reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False) + 
reproduce_batch_sampler.load_state_dict(state) + + real_res = [] + supposed_res = (list(range(12, 16)), list(range(16, 20))) + forward_steps = 2 + iter_dataloader = iter(reproduce_batch_sampler) + for _ in range(forward_steps): + real_res.append(next(iter_dataloader)) + + for i in range(forward_steps): + assert supposed_res[i] == real_res[i] + + # 改变 batchsize; + sampler = NormalSampler(num_of_data=100) # 这里是否是 batchsampler 不影响; + reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False) + reproduce_batch_sampler.load_state_dict(state) + + real_res = [] + supposed_res = (list(range(12, 19)), list(range(19, 26))) + forward_steps = 2 + iter_dataloader = iter(reproduce_batch_sampler) + for _ in range(forward_steps): + real_res.append(next(iter_dataloader)) + + for i in range(forward_steps): + assert supposed_res[i] == real_res[i] + + # 断点重训的第二轮是否是一个完整的 dataloader; + # 先把断点重训所在的那一个 epoch 跑完; + begin_idx = 26 + while True: + try: + data = next(iter_dataloader) + _batch_size = len(data) + assert data == list(range(begin_idx, begin_idx + _batch_size)) + begin_idx += _batch_size + except StopIteration: + break + + # 开始新的一轮; + begin_idx = 0 + iter_dataloader = iter(reproduce_batch_sampler) + while True: + try: + data = next(iter_dataloader) + _batch_size = len(data) + assert data == list(range(begin_idx, begin_idx + _batch_size)) + begin_idx += _batch_size + except StopIteration: + break -from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler -from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler -from tests.helpers.datasets.torch_data import TorchNormalDataset - -# -# class TestReproducibleBatchSampler: -# # TODO 拆分测试,在这里只测试一个东西 -# def test_torch_dataloader_1(self): -# import torch -# from torch.utils.data import DataLoader -# # no shuffle -# before_batch_size = 7 -# dataset = TorchNormalDataset(num_of_data=100) -# dataloader = DataLoader(dataset, batch_size=before_batch_size) -# re_batchsampler = 
RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) -# dataloader = replace_batch_sampler(dataloader, re_batchsampler) -# -# forward_steps = 3 -# iter_dataloader = iter(dataloader) -# for _ in range(forward_steps): -# next(iter_dataloader) -# -# # 1. 保存状态 -# _get_re_batchsampler = dataloader.batch_sampler -# assert isinstance(_get_re_batchsampler, RandomBatchSampler) -# state = _get_re_batchsampler.state_dict() -# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, -# "sampler_type": "RandomBatchSampler"} -# -# # 2. 断点重训,重新生成一个 dataloader; -# # 不改变 batch_size; -# dataloader = DataLoader(dataset, batch_size=before_batch_size) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) -# re_batchsampler.load_state_dict(state) -# dataloader = replace_batch_sampler(dataloader, re_batchsampler) -# -# real_res = [] -# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) -# forward_steps = 2 -# iter_dataloader = iter(dataloader) -# for _ in range(forward_steps): -# real_res.append(next(iter_dataloader)) -# -# for i in range(forward_steps): -# assert all(real_res[i] == supposed_res[i]) -# -# # 改变 batch_size; -# after_batch_size = 3 -# dataloader = DataLoader(dataset, batch_size=after_batch_size) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) -# re_batchsampler.load_state_dict(state) -# dataloader = replace_batch_sampler(dataloader, re_batchsampler) -# -# real_res = [] -# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) -# forward_steps = 2 -# iter_dataloader = iter(dataloader) -# for _ in range(forward_steps): -# real_res.append(next(iter_dataloader)) -# -# for i in range(forward_steps): -# assert all(real_res[i] == supposed_res[i]) -# -# # 断点重训的第二轮是否是一个完整的 dataloader; -# # 先把断点重训所在的那一个 epoch 跑完; -# 
begin_idx = 27 -# while True: -# try: -# data = next(iter_dataloader) -# _batch_size = len(data) -# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) -# begin_idx += _batch_size -# except StopIteration: -# break -# -# # 开始新的一轮; -# begin_idx = 0 -# iter_dataloader = iter(dataloader) -# while True: -# try: -# data = next(iter_dataloader) -# _batch_size = len(data) -# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) -# begin_idx += _batch_size -# except StopIteration: -# break -# -# def test_torch_dataloader_2(self): -# # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; -# from torch.utils.data import DataLoader -# # no shuffle -# before_batch_size = 7 -# dataset = TorchNormalDataset(num_of_data=100) -# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; -# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) -# dataloader = replace_batch_sampler(dataloader, re_batchsampler) -# -# # 将一轮的所有数据保存下来,看是否恢复的是正确的; -# all_supposed_data = [] -# forward_steps = 3 -# iter_dataloader = iter(dataloader) -# for _ in range(forward_steps): -# all_supposed_data.extend(next(iter_dataloader).tolist()) -# -# # 1. 保存状态 -# _get_re_batchsampler = dataloader.batch_sampler -# assert isinstance(_get_re_batchsampler, RandomBatchSampler) -# state = _get_re_batchsampler.state_dict() -# -# # 2. 
断点重训,重新生成一个 dataloader; -# # 不改变 batch_size; -# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) -# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) -# re_batchsampler.load_state_dict(state) -# dataloader = replace_batch_sampler(dataloader, re_batchsampler) -# -# # 先把这一轮的数据过完; -# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] -# while True: -# try: -# all_supposed_data.extend(next(iter_dataloader).tolist()) -# except StopIteration: -# break -# assert all_supposed_data == list(pre_index_list) -# -# # 重新开启新的一轮; -# for _ in range(3): -# iter_dataloader = iter(dataloader) -# res = [] -# while True: -# try: -# res.append(next(iter_dataloader)) -# except StopIteration: -# break -# -# def test_3(self): -# import torch -# from torch.utils.data import DataLoader -# before_batch_size = 7 -# dataset = TorchNormalDataset(num_of_data=100) -# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; -# dataloader = DataLoader(dataset, batch_size=before_batch_size) -# -# for idx, data in enumerate(dataloader): -# if idx > 3: -# break -# -# iterator = iter(dataloader) -# for each in iterator: -# pass + def test_2(self): + + # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; + before_batch_size = 7 + sampler = NormalSampler(num_of_data=100) + # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; + reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False) + + # 将一轮的所有数据保存下来,看是否恢复的是正确的; + all_supposed_data = [] + forward_steps = 3 + iter_dataloader = iter(reproduce_batch_sampler) + for _ in range(forward_steps): + all_supposed_data.extend(next(iter_dataloader)) + + # 1. 保存状态 + state = reproduce_batch_sampler.state_dict() + + # 2. 
断点重训,重新生成一个 dataloader; + # 不改变 batch_size; + sampler = NormalSampler(num_of_data=100, shuffle=True) + reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False) + reproduce_batch_sampler.load_state_dict(state) + + # 先把这一轮的数据过完; + pre_index_list = reproduce_batch_sampler.state_dict()["index_list"] + iter_dataloader = iter(reproduce_batch_sampler) + while True: + try: + all_supposed_data.extend(next(iter_dataloader)) + except StopIteration: + break + assert all_supposed_data == list(pre_index_list) + + # 重新开启新的一轮; + for _ in range(3): + iter_dataloader = iter(reproduce_batch_sampler) + res = [] + while True: + try: + res.extend(next(iter_dataloader)) + except StopIteration: + break + assert res != all_supposed_data class DatasetWithVaryLength: @@ -511,3 +481,313 @@ class TestBucketedBatchSampler: already_seen_set.update(batch) assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) + + +class TestRandomBatchSampler: + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71]) + def test_single_num_batch(self, shuffle, drop_last, num): + # 数量不够不报错 + for num in [2, 7, 14, 15, 70, 71]: + dataset = DatasetWithVaryLength(num_of_data=num) + before_batch_size = 7 + re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + drop_last=drop_last, + shuffle=shuffle) + count = len(list(iter(re_batchsampler))) + if drop_last: + assert count==num//before_batch_size, num + else: + assert count==(num+before_batch_size-1)//before_batch_size, num + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + def test_single(self, shuffle, drop_last): + + before_batch_size = 7 + num_batch_per_bucket = 4 # 那么任意 batch 内的长度差值不应该超过4 + + dataset = DatasetWithVaryLength(num_of_data=1000) + re_batchsampler = 
RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + drop_last=drop_last, + shuffle=shuffle) + re_batchsampler.set_epoch(0) + forward_steps = 10 + iterator = iter(re_batchsampler) + already_generate_indices = set() + for _ in range(forward_steps): + batch = next(iterator) + already_generate_indices.update(batch) + + # 1. 保存状态 + state = re_batchsampler.state_dict() + + # 2. 断点重训,继续训练 + re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + drop_last=drop_last, + shuffle=shuffle) + re_batchsampler2.load_state_dict(state) + re_batchsampler2.set_epoch(0) + new_already_generate_indices = set() + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices)-before_batch_size * num_batch_per_bucket): + max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i]) + for batch in re_batchsampler2: + for b in batch: + assert b not in already_generate_indices + new_already_generate_indices.update(batch) + if drop_last is False: + assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset) + + # 改变 batch_size; + after_batch_size = 3 + re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, + drop_last=drop_last, + shuffle=shuffle) + re_batchsampler3.load_state_dict(state) + re_batchsampler3.set_epoch(0) + count = 0 + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + + for batch in re_batchsampler3: + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + count += 1 + if count > 5: + break + + # 再 save ,不允许再上个epoch没结束继续sample + after_batch_size = 5 + with pytest.raises(RuntimeError): + state = re_batchsampler3.state_dict() + + for batch in re_batchsampler3: # consume all, 这样才能save 
+ pass + + already_generate_indices = set() + count = 0 + for batch in re_batchsampler3: # 重新开始 + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + count += 1 + if count > 5: + break + + state = re_batchsampler3.state_dict() + # 这里的 drop_last 为 False,需要最终是所有 sample + re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, + drop_last=False, + shuffle=shuffle) + re_batchsampler4.load_state_dict(state) + re_batchsampler4.set_epoch(0) + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + for batch in re_batchsampler4: + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + + assert len(already_generate_indices) == len(dataset) + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + def test_multi(self, shuffle, drop_last, pad): + # def test_multi(self, shuffle=True, drop_last=False, pad=False): + + # no shuffle + num_replica = 2 + dataset = DatasetWithVaryLength(num_of_data=1000) + batch_size = 5 + num_batch_per_bucket = 10 + lengths = [] + rank0_already_seen_indexes = None + max_diff = num_batch_per_bucket * batch_size * num_replica + for rank in range(num_replica): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size, + shuffle = shuffle, drop_last=drop_last) + sampler.set_epoch(0) + sampler.set_distributed(num_replica, rank=rank, pad=pad) + lengths.append(len(sampler)) + already_seen_indexes = set() + repeat_count = 0 + for batch in sampler: + for b in batch: + repeat_count += int(b in already_seen_indexes) + if rank0_already_seen_indexes: # 不能交叉出现 + assert b not in rank0_already_seen_indexes + already_seen_indexes.update(batch) + if rank0_already_seen_indexes is None: + rank0_already_seen_indexes = already_seen_indexes + if pad: # 应该允许重复一次 + assert 
repeat_count<=1 + else: + assert repeat_count==0 + + assert len(set(lengths))==1, lengths # 每个进程的batch数量一致 + + # 多进程的保存 + already_seen_indexes = set() + for rank in range(num_replica): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size = batch_size, + shuffle = shuffle, drop_last=drop_last) + sampler.set_epoch(0) + sampler.set_distributed(num_replica, rank=rank, pad=pad) + lengths.append(len(sampler)) + count = 0 + for batch in sampler: + already_seen_indexes.update(batch) + if count>5: + break + count += 1 + state = sampler.state_dict() + + # 切换成单机 + new_batch_size = 6 + num_batch_per_bucket = 3 + new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, + shuffle=shuffle, drop_last=drop_last) + new_sampler.load_state_dict(state) + repeat_count = 0 + new_already_seen_indexes = set(list(already_seen_indexes)) + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_seen_indexes)] = 0 + indices = np.arange(len(dataset))[mask] + + for batch in new_sampler: + for b in batch: + repeat_count += int(b in new_already_seen_indexes) + new_already_seen_indexes.update(batch) + if pad: # 应该允许重复一次 + assert repeat_count <= 1 + else: + assert repeat_count == 0 + if drop_last is False: # 如果没有drop应该相等 + assert len(new_already_seen_indexes)==len(dataset) + + # 测试替换卡的数量。 + num_replica = 3 + new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, + shuffle=shuffle, drop_last=drop_last) + new_sampler.set_epoch(0) + new_sampler.load_state_dict(state) + new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad) + repeat_count = 0 + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_seen_indexes)] = 0 + indices = np.arange(len(dataset))[mask] + + for batch in new_sampler: + for b in batch: + repeat_count += int(b in already_seen_indexes) + if pad: # 应该允许重复一次 + assert repeat_count <= 1 + else: + assert repeat_count == 0 + + @pytest.mark.parametrize('shuffle', [True, 
False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replicas', [2, 3]) + def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas): + # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replicas=2): + dataset = DatasetWithVaryLength(num_of_data=num_samples) + batch_size = 6 + if num_replicas*batch_size > num_samples: + return + num_batch_per_bucket = 10 + samplers = [] + lengths = [] + for i in range(num_replicas): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size, + shuffle=shuffle, drop_last=drop_last) + sampler.set_distributed(num_replicas, rank=i, pad=pad) + sampler.set_epoch(0) + samplers.append(sampler) + lengths.append(len(list(iter(sampler)))) + assert len(set(lengths))==1 + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replicas', [1, 2, 3]) + def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas): + """ + 测试是否能够正确地恢复使用过的(forward)数据 + + :return: + """ + batch_size = 6 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + samplers = [] + num_consumed_samples_array = list(range(0, num_samples+num_replicas, num_replicas)) + for i in range(num_replicas): + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size, + shuffle=shuffle, drop_last=drop_last) + + sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad) + samplers.append(sampler) + count = 0 + already_seen_sets = [set()] + already_seen_set = set() + for batchs in zip(*samplers): + batch = chain(*batchs) + already_seen_set.update(batch) + 
already_seen_sets.append(deepcopy(already_seen_set)) + count += 1 + if count > 3: + break + states = samplers[0].state_dict() + for i in range(len(already_seen_sets)): + states['num_consumed_samples'] = num_consumed_samples_array[i] + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size+1, + shuffle=shuffle, drop_last=drop_last) + sampler.set_epoch(0) + already_seen_set = deepcopy(already_seen_sets[i]) + for batch in sampler: + already_seen_set.update(batch) + assert len(already_seen_set) == len(dataset) if drop_last is False else len(already_seen_set) <= len( + dataset) + + # 测试保存之后再次保存 + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1, + shuffle=shuffle, + drop_last=drop_last) + sampler.set_epoch(0) + states['num_consumed_samples'] = num_consumed_samples_array[2] + if len(already_seen_sets)<3: + return + already_seen_set = already_seen_sets[2] + count = 0 + for batch in sampler: + already_seen_set.update(batch) + count += 1 + if count > 6: + break + + states = sampler.state_dict() + num_consumed_samples_array = list(range(len(dataset))) + states['num_consumed_samples'] = num_consumed_samples_array[count] + sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size//2, + shuffle=shuffle, + drop_last=drop_last) + sampler.load_state_dict(states) + sampler.set_epoch(0) + for batch in sampler: + already_seen_set.update(batch) + + assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) diff --git a/tests/core/samplers/test_reproducible_batch_sampler_torch.py b/tests/core/samplers/test_reproducible_batch_sampler_torch.py new file mode 100644 index 00000000..af180f56 --- /dev/null +++ b/tests/core/samplers/test_reproducible_batch_sampler_torch.py @@ -0,0 +1,141 @@ +from array import array +import torch +from torch.utils.data import DataLoader + +import pytest + +from fastNLP.core.samplers import ReproduceBatchSampler +from 
fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler +from tests.helpers.datasets.torch_data import TorchNormalDataset + + +@pytest.mark.torch +class TestReproducibleBatchSamplerTorch: + def test_torch_dataloader_1(self): + # no shuffle + before_batch_size = 7 + dataset = TorchNormalDataset(num_of_data=100) + dataloader = DataLoader(dataset, batch_size=before_batch_size) + re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + forward_steps = 3 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + next(iter_dataloader) + + # 1. 保存状态 + _get_re_batchsampler = dataloader.batch_sampler + assert isinstance(_get_re_batchsampler, ReproduceBatchSampler) + state = _get_re_batchsampler.state_dict() + assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, + "sampler_type": "ReproduceBatchSampler"} + + # 2. 
断点重训,重新生成一个 dataloader; + # 不改变 batch_size; + dataloader = DataLoader(dataset, batch_size=before_batch_size) + re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler.load_state_dict(state) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + real_res = [] + supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) + forward_steps = 2 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + real_res.append(next(iter_dataloader)) + + for i in range(forward_steps): + assert all(real_res[i] == supposed_res[i]) + + # 改变 batch_size; + after_batch_size = 3 + dataloader = DataLoader(dataset, batch_size=after_batch_size) + re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler.load_state_dict(state) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + real_res = [] + supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) + forward_steps = 2 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + real_res.append(next(iter_dataloader)) + + for i in range(forward_steps): + assert all(real_res[i] == supposed_res[i]) + + # 断点重训的第二轮是否是一个完整的 dataloader; + # 先把断点重训所在的那一个 epoch 跑完; + begin_idx = 27 + while True: + try: + data = next(iter_dataloader) + _batch_size = len(data) + assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) + begin_idx += _batch_size + except StopIteration: + break + + # 开始新的一轮; + begin_idx = 0 + iter_dataloader = iter(dataloader) + while True: + try: + data = next(iter_dataloader) + _batch_size = len(data) + assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) + begin_idx += _batch_size + except StopIteration: + break + + def test_torch_dataloader_2(self): + # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; + from torch.utils.data import DataLoader + 
before_batch_size = 7 + dataset = TorchNormalDataset(num_of_data=100) + # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; + dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) + re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + # 将一轮的所有数据保存下来,看是否恢复的是正确的; + all_supposed_data = [] + forward_steps = 3 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + all_supposed_data.extend(next(iter_dataloader).tolist()) + + # 1. 保存状态 + _get_re_batchsampler = dataloader.batch_sampler + assert isinstance(_get_re_batchsampler, ReproduceBatchSampler) + state = _get_re_batchsampler.state_dict() + + # 2. 断点重训,重新生成一个 dataloader; + # 不改变 batch_size; + dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) + re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler.load_state_dict(state) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + iter_dataloader = iter(dataloader) + # 先把这一轮的数据过完; + pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] + while True: + try: + all_supposed_data.extend(next(iter_dataloader).tolist()) + except StopIteration: + break + assert all_supposed_data == list(pre_index_list) + + # 重新开启新的一轮; + for _ in range(3): + iter_dataloader = iter(dataloader) + res = [] + while True: + try: + res.extend(next(iter_dataloader).tolist()) + except StopIteration: + break + assert res != all_supposed_data + diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 5657ae81..77c618bb 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -3,6 +3,7 @@ import pytest import subprocess from io import StringIO import sys +sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) from 
fastNLP.core.utils.cache_results import cache_results from fastNLP.core import rank_zero_rm diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 395c854d..170110ce 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -1,4 +1,5 @@ import os +import pytest from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing @@ -9,7 +10,7 @@ def test_dump_fastnlp_envs(): filepath = None try: with Capturing() as output: - dump_fastnlp_backend() + dump_fastnlp_backend(backend="torch") filepath = os.path.join(os.path.expanduser('~'), '.fastNLP', 'envs', os.environ['CONDA_DEFAULT_ENV']+'.json') assert filepath in output[0] assert os.path.exists(filepath) diff --git a/tests/helpers/callbacks/helper_callbacks_torch.py b/tests/helpers/callbacks/helper_callbacks_torch.py index a197bb33..4b9730da 100644 --- a/tests/helpers/callbacks/helper_callbacks_torch.py +++ b/tests/helpers/callbacks/helper_callbacks_torch.py @@ -1,7 +1,9 @@ -import torch from copy import deepcopy from fastNLP.core.callbacks.callback import Callback +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch class RecordAccumulationStepsCallback_Torch(Callback): diff --git a/tests/helpers/datasets/normal_data.py b/tests/helpers/datasets/normal_data.py index 714ec676..b4e3ffca 100644 --- a/tests/helpers/datasets/normal_data.py +++ b/tests/helpers/datasets/normal_data.py @@ -1,13 +1,25 @@ import numpy as np +import random -class NormalIterator: - def __init__(self, num_of_data=1000): +class NormalSampler: + def __init__(self, num_of_data=1000, shuffle=False): self._num_of_data = num_of_data self._data = list(range(num_of_data)) + if shuffle: + random.shuffle(self._data) + self.shuffle = shuffle self._index = 0 + self.need_reinitialize = False def __iter__(self): + if self.need_reinitialize: + self._index = 0 + if self.shuffle: + random.shuffle(self._data) + else: + 
self.need_reinitialize = True + return self def __next__(self): @@ -15,12 +27,45 @@ class NormalIterator: raise StopIteration _data = self._data[self._index] self._index += 1 - return self._data + return _data def __len__(self): return self._num_of_data +class NormalBatchSampler: + def __init__(self, sampler, batch_size: int, drop_last: bool) -> None: + # Since collections.abc.Iterable does not check for `__getitem__`, which + # is one way for an object to be an iterable, we don't do an `isinstance` + # check here. + if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \ + batch_size <= 0: + raise ValueError("batch_size should be a positive integer value, " + "but got batch_size={}".format(batch_size)) + if not isinstance(drop_last, bool): + raise ValueError("drop_last should be a boolean value, but got " + "drop_last={}".format(drop_last)) + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self) -> int: + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size + + class RandomDataset: def __init__(self, num_data=10): self.data = np.random.rand(num_data) @@ -29,4 +74,7 @@ class RandomDataset: return len(self.data) def __getitem__(self, item): - return self.data[item] \ No newline at end of file + return self.data[item] + + + diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py index 9a0af019..7c9056cd 100644 --- a/tests/helpers/datasets/torch_data.py +++ b/tests/helpers/datasets/torch_data.py @@ -1,7 +1,11 @@ import torch from functools import reduce -from torch.utils.data import Dataset, DataLoader, DistributedSampler -from torch.utils.data.sampler import SequentialSampler, 
BatchSampler +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import Dataset, DataLoader, DistributedSampler + from torch.utils.data.sampler import SequentialSampler, BatchSampler +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset class TorchNormalDataset(Dataset): diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index 236ffda5..afb441ce 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -1,9 +1,14 @@ -import torch -import torch.nn as nn +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + import torch + from torch.nn import Module + import torch.nn as nn +else: + from fastNLP.core.utils.dummy_class import DummyClass as Module # 1. 最为基础的分类模型 -class TorchNormalModel_Classification_1(nn.Module): +class TorchNormalModel_Classification_1(Module): """ 单独实现 train_step 和 evaluate_step; """ @@ -38,7 +43,7 @@ class TorchNormalModel_Classification_1(nn.Module): return {"preds": x, "target": y} -class TorchNormalModel_Classification_2(nn.Module): +class TorchNormalModel_Classification_2(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; """ @@ -62,7 +67,7 @@ class TorchNormalModel_Classification_2(nn.Module): return {"loss": loss, "preds": x, "target": y} -class TorchNormalModel_Classification_3(nn.Module): +class TorchNormalModel_Classification_3(Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; 关闭 auto_param_call,forward 只有一个 batch 参数; diff --git a/tests/pytest.ini b/tests/pytest.ini new file mode 100644 index 00000000..d6a33a94 --- /dev/null +++ b/tests/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +markers = + torch + paddle + jittor + torchpaddle \ No newline at end of file