diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 11697bdc..d710f967 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
 from fastNLP.envs import rank_zero_call
-from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME
diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py
index 84c5f373..019e6fad 100644
--- a/fastNLP/core/drivers/driver.py
+++ b/fastNLP/core/drivers/driver.py
@@ -49,13 +49,13 @@ class Driver(ABC):
             duplicated across different gpus; 'unrepeatdist' means the dataloader should guarantee that the data iterated
             on all gpus, merged together, is exactly the original data, while allowing different gpus to have different
             numbers of batches. When the trainer kwarg `use_dist_sampler` is True, this value is "dist", otherwise it is
             None; when the evaluator kwarg `use_dist_sampler` is True, this value is "unrepeatdist", otherwise it is None;
-            note that when dist is a ReproducibleIterator or RandomBatchSampler, this function is being called by driver.load during checkpoint resumption;
+            note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, this function is being called by driver.load during checkpoint resumption;
             when dist is a str or None, this function is called by the trainer during initialization;
         :param reproducible: if False, no special handling is needed; if True, the returned dataloader must be able to
             save its current iteration state, so that the state can be loaded back later.
         :return: a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here); besides,
-            if the dataloader passed in contains a ReproducibleSampler or RandomBatchSampler, a newly initialized one should be put into the returned
+            if the dataloader passed in contains a ReproducibleSampler or ReproducibleBatchSampler, a newly initialized one should be put into the returned
             dataloader. If dist is None and reproducible is False, the original object can be returned directly.
         """
         if dist is None and reproducible is False:
diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py
index 0bd7188d..4c99a2f5 100644
--- a/fastNLP/core/drivers/jittor_driver/single_device.py
+++ b/fastNLP/core/drivers/jittor_driver/single_device.py
@@ -3,7 +3,7 @@ from typing import Dict, Union
 from .jittor_driver import JittorDriver
 from fastNLP.core.utils import auto_param_call
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 
 if _NEED_IMPORT_JITTOR:
     import jittor
@@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver):
     def is_distributed(self):
         return False
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # the reproducible-related functionality is not implemented yet
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             raise NotImplementedError
             dataloader.batch_sampler = dist_sample
         if isinstance(dist, ReproducibleSampler):
diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py
index a124b9be..c57ba14d 100644
--- a/fastNLP/core/drivers/paddle_driver/single_device.py
+++ b/fastNLP/core/drivers/paddle_driver/single_device.py
@@ -10,7 +10,7 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_PADDLE:
@@ -139,12 +139,12 @@ class PaddleSingleDriver(PaddleDriver):
         """
         return paddle_move_data_to_device(batch, "gpu:0")
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # IteratorDataset is not supported yet
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             dataloader.batch_sampler = dist
             return dataloader
         if isinstance(dist, ReproducibleSampler):
@@ -154,11 +154,11 @@ class PaddleSingleDriver(PaddleDriver):
         if reproducible:
             if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
                 return dataloader
-            elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
+            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
                 return dataloader
             else:
                 # TODO
-                batch_sampler = RandomBatchSampler(
+                batch_sampler = ReproducibleBatchSampler(
                     batch_sampler=dataloader.batch_sampler,
                     batch_size=dataloader.batch_sampler.batch_size,
                     drop_last=dataloader.drop_last
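Note: across the four drivers above, the type hints and isinstance checks now target ReproducibleBatchSampler, the abstract base introduced at the end of this diff, instead of the concrete RandomBatchSampler, so every concrete reproducible batch sampler takes the same code path. Also worth flagging: in JittorSingleDriver, the context line `dataloader.batch_sampler = dist_sample` is unreachable after the `raise` and references an undefined name `dist_sample`; it looks like leftover code. A minimal, self-contained sketch (plain Python, not fastNLP code) of the base-class dispatch these changes rely on:

    from abc import ABC, abstractmethod

    class ReproducibleBatchSampler(ABC):               # abstract base
        @abstractmethod
        def state_dict(self) -> dict: ...

    class RandomBatchSampler(ReproducibleBatchSampler):    # one concrete sampler
        def state_dict(self) -> dict:
            return {}

    class BucketedBatchSampler(ReproducibleBatchSampler):  # another concrete sampler
        def state_dict(self) -> dict:
            return {}

    def set_dist_repro_dataloader(dist):
        # dispatching on the base class accepts every concrete subclass
        if isinstance(dist, ReproducibleBatchSampler):
            return "replace the dataloader's batch_sampler"
        return "leave the dataloader unchanged"

    assert set_dist_repro_dataloader(RandomBatchSampler()) == "replace the dataloader's batch_sampler"
    assert set_dist_repro_dataloader(BucketedBatchSampler()) == "replace the dataloader's batch_sampler"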
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index e1408df1..e19aa648 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
 )
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
     re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
@@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver):
             # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
             return self._test_step(batch)
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
                                   reproducible: bool = False):
-        # if dist is a RandomBatchSampler or ReproducibleIterator, we are being called by driver.load during checkpoint resumption;
+        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, we are being called by driver.load during checkpoint resumption;
         # note that there is no need to call dist_sampler.set_distributed here, because if the user uses TorchDDPDriver, that function has already been called during Trainer initialization;
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
            dist.set_distributed(
                num_replicas=self.world_size,
                rank=self.global_rank,
@@ -472,7 +472,7 @@ class TorchDDPDriver(TorchDriver):
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
-                if isinstance(dist, RandomBatchSampler):
+                if isinstance(dist, ReproducibleBatchSampler):
                     dist = re_instantiate_sampler(dist)
                 return replace_batch_sampler(dataloader, dist)
             if isinstance(dist, ReproducibleSampler):
@@ -483,7 +483,7 @@ class TorchDDPDriver(TorchDriver):
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, then whether checkpoint resumption is in progress does not affect the behavior here;
-            if isinstance(args.batch_sampler, RandomBatchSampler):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
                 batch_sampler = re_instantiate_sampler(args.batch_sampler)
                 batch_sampler.set_distributed(
                     num_replicas=self.world_size,
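Note: in the resumption branch above, the restored batch sampler receives set_distributed(num_replicas=self.world_size, rank=self.global_rank, ...), i.e. each rank keeps only its own share of the batches. The helper below is a hypothetical, simplified illustration of that sharding arithmetic (round-robin by rank), not the fastNLP implementation:

    def shard_batches(batches, num_replicas, rank):
        # rank r keeps batches r, r + num_replicas, r + 2 * num_replicas, ...
        return batches[rank::num_replicas]

    batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
    assert shard_batches(batches, num_replicas=2, rank=0) == [[0, 1], [4, 5]]
    assert shard_batches(batches, num_replicas=2, rank=1) == [[2, 3], [6, 7]]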
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
index cf8c19a8..19e687b8 100644
--- a/fastNLP/core/drivers/torch_driver/single_device.py
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -13,7 +13,7 @@ __all__ = [
 from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
 from fastNLP.core.log import logger
 
 
@@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._test_step(batch)
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
                                   reproducible: bool = False):
-        # if dist is a RandomBatchSampler or ReproducibleIterator, we are being called by driver.load during checkpoint resumption;
-        if isinstance(dist, RandomBatchSampler):
+        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, we are being called by driver.load during checkpoint resumption;
+        if isinstance(dist, ReproducibleBatchSampler):
             return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleSampler):
             return replace_sampler(dataloader, dist)
 
         # if dist is a str or None, we are being called during trainer initialization;
         args = self.get_dataloader_args(dataloader)
-        if isinstance(args.batch_sampler, RandomBatchSampler):
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
             batch_sampler = re_instantiate_sampler(args.batch_sampler)
             return replace_batch_sampler(dataloader, batch_sampler)
         elif isinstance(args.sampler, ReproducibleSampler):
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver):
             return replace_sampler(dataloader, sampler)
 
         if reproducible:
-            batch_sampler = RandomBatchSampler(
+            batch_sampler = ReproducibleBatchSampler(
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,
                 drop_last=args.drop_last
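Note: the last branch above is the single-card fallback: when the user's dataloader only has a plain batch_sampler and reproducible=True, it gets wrapped in a reproducible batch sampler so that already-consumed batches can be skipped after a resume. A sketch of that wrapping idea; BatchSamplerWrapper and its attributes are illustrative assumptions, not the fastNLP class:

    class BatchSamplerWrapper:
        def __init__(self, batch_sampler, batch_size, drop_last):
            self.batch_sampler = batch_sampler
            self.batch_size = batch_size
            self.drop_last = drop_last
            self.num_consumed = 0          # would be persisted via state_dict on save

        def __iter__(self):
            for i, batch in enumerate(self.batch_sampler):
                if i < self.num_consumed:  # on resume, skip batches already seen
                    continue
                self.num_consumed += 1
                yield batch
            self.num_consumed = 0          # epoch finished, start fresh next time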
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index b3386f5a..b200f1fd 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 
 
 class TorchDriver(Driver):
@@ -183,9 +183,9 @@ class TorchDriver(Driver):
         # 1. the sampler state, because we support resume training, i.e. resuming exactly at a specific batch;
         # first, a pytorch DataLoader always has a sampler; moreover, during checkpoint resumption we always replace the dataloader's
-        # sampler with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, the batch_sampler is replaced with a `RandomBatchSampler`;
+        # sampler with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
         dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
         elif dataloader_args.sampler:
             sampler = dataloader_args.sampler
@@ -245,15 +245,14 @@ class TorchDriver(Driver):
         # 3. restore the sampler state;
         dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
-        elif isinstance(dataloader_args.sampler, ReproducibleIterator):
+        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
             sampler = dataloader_args.sampler
         elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
-                               "`RandomBatchSampler` or `ReproducibleIterator`.")
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our `ReproducibleBatchSampler` or `ReproducibleSampler`.")
         else:
-            sampler = RandomBatchSampler(
+            sampler = ReproducibleBatchSampler(
                 batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
@@ -263,7 +262,7 @@ class TorchDriver(Driver):
         # 4. update trainer_state.batch_idx_in_epoch
         # here sampler is a RandomSampler-like sampler, not a batch_sampler;
-        if not isinstance(sampler, RandomBatchSampler):
+        if not isinstance(sampler, ReproducibleBatchSampler):
             if dataloader_args.drop_last:
                 batch_idx_in_epoch = len(
                     sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py
index 3d6813f7..c3cc2d39 100644
--- a/fastNLP/core/samplers/__init__.py
+++ b/fastNLP/core/samplers/__init__.py
@@ -19,6 +19,10 @@ __all__ = [
     "UnrepeatedSortedSampler",
     "UnrepeatedSequentialSampler",
 
+    "RandomBatchSampler",
+    "BucketedBatchSampler",
+    "ReproducibleBatchSampler",
+
     "re_instantiate_sampler",
     "conversion_between_reproducible_and_unrepeated_sampler"
 ]
@@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre
 from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
 from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
 from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
-from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler
+from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
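Note: with the __all__ and re-export changes above, the abstract base can now be imported alongside the concrete batch samplers. The snippet below only uses names this diff exports:

    from fastNLP.core.samplers import (
        RandomBatchSampler,        # concrete subclass; wraps an existing batch_sampler (see below)
        BucketedBatchSampler,      # concrete subclass
        ReproducibleBatchSampler,  # abstract base used in the driver type hints
    )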
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index 5a25110b..c4116e24 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -17,6 +17,9 @@ from abc import abstractmethod
 
 
 class ReproducibleBatchSampler:
+    def __init__(self, **kwargs):
+        pass
+
     @abstractmethod
     def set_distributed(self, num_replicas, rank, pad=True):
         raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
@@ -41,6 +44,10 @@ class ReproducibleBatchSampler:
     def set_epoch(self, epoch):
         pass
 
+    @property
+    def batch_idx_in_epoch(self):
+        raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
+
 
 class RandomBatchSampler(ReproducibleBatchSampler):
     # the values of these two parameters should be obtained via the driver's get_dataloader_args function;
@@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         :param drop_last: whether to drop the last batch if it cannot be filled with batch_size samples.
         :param kwargs: for internal use by fastNLP.
         """
+        super().__init__()
+
         self.batch_sampler = batch_sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
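Note: the new abstract `batch_idx_in_epoch` property gives drivers a uniform way to read how many batches a sampler has already produced, presumably consumed by code like the torch_driver.py restore path above. A hedged sketch of a concrete subclass satisfying it; CountingBatchSampler and its attributes are illustrative assumptions, and the arithmetic mirrors the drop_last formula in torch_driver.py (total batches minus batches still left):

    from fastNLP.core.samplers import ReproducibleBatchSampler

    class CountingBatchSampler(ReproducibleBatchSampler):
        def __init__(self, num_samples, batch_size, drop_last=False, **kwargs):
            super().__init__(**kwargs)            # the base __init__ added above accepts kwargs
            self.num_samples = num_samples
            self.batch_size = batch_size
            self.drop_last = drop_last
            self.num_left_samples = num_samples   # would be decremented as batches are yielded

        @property
        def batch_idx_in_epoch(self):
            # batches already drawn this epoch = total batches - batches still left
            if self.drop_last:
                return (self.num_samples // self.batch_size
                        - self.num_left_samples // self.batch_size)
            return ((self.num_samples + self.batch_size - 1) // self.batch_size
                    - (self.num_left_samples + self.batch_size - 1) // self.batch_size)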