@@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver | |||
from fastNLP.core.drivers.utils import choose_driver | |||
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME | |||
@@ -49,13 +49,13 @@ class Driver(ABC): | |||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | |||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | |||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | |||
注意当 dist 为 ReproducibleIterator, RandomBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | |||
可以可以加载。 | |||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | |||
如果传入的 dataloader 中是 ReproducibleSampler 或者 RandomBatchSampler 需要重新初始化一个放入返回的 | |||
如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 | |||
dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 | |||
""" | |||
if dist is None and reproducible is False: | |||
@@ -3,7 +3,7 @@ from typing import Dict, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
@@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver): | |||
def is_distributed(self): | |||
return False | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# reproducible 的相关功能暂时没有实现 | |||
if isinstance(dist, RandomBatchSampler): | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
raise NotImplementedError | |||
dataloader.batch_sampler = dist_sample | |||
if isinstance(dist, ReproducibleSampler): | |||
@@ -10,7 +10,7 @@ from fastNLP.core.utils import ( | |||
get_paddle_device_id, | |||
paddle_move_data_to_device, | |||
) | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
@@ -139,12 +139,12 @@ class PaddleSingleDriver(PaddleDriver): | |||
""" | |||
return paddle_move_data_to_device(batch, "gpu:0") | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler], | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], | |||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||
# 暂时不支持IteratorDataset | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
if isinstance(dist, RandomBatchSampler): | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dataloader.batch_sampler = dist | |||
return dataloader | |||
if isinstance(dist, ReproducibleSampler): | |||
@@ -154,11 +154,11 @@ class PaddleSingleDriver(PaddleDriver): | |||
if reproducible: | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, RandomBatchSampler): | |||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
@@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import ( | |||
) | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.utils import auto_param_call, check_user_specific_params | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \ | |||
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \ | |||
re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED | |||
from fastNLP.core.log import logger | |||
@@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver): | |||
# return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None, | |||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; | |||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||
if isinstance(dist, RandomBatchSampler): | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dist.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
@@ -472,7 +472,7 @@ class TorchDDPDriver(TorchDriver): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | |||
"control.") | |||
else: | |||
if isinstance(dist, RandomBatchSampler): | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
dist = re_instantiate_sampler(dist) | |||
return replace_batch_sampler(dataloader, dist) | |||
if isinstance(dist, ReproducibleSampler): | |||
@@ -483,7 +483,7 @@ class TorchDDPDriver(TorchDriver): | |||
elif dist == "dist": | |||
args = self.get_dataloader_args(dataloader) | |||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
if isinstance(args.batch_sampler, RandomBatchSampler): | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
batch_sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
@@ -13,7 +13,7 @@ __all__ = [ | |||
from .torch_driver import TorchDriver | |||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
from fastNLP.core.log import logger | |||
@@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver): | |||
else: | |||
return self._test_step(batch) | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None, | |||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, | |||
reproducible: bool = False): | |||
# 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
if isinstance(dist, RandomBatchSampler): | |||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||
if isinstance(dist, ReproducibleBatchSampler): | |||
return replace_batch_sampler(dataloader, dist) | |||
elif isinstance(dist, ReproducibleSampler): | |||
return replace_sampler(dataloader, dist) | |||
# 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; | |||
args = self.get_dataloader_args(dataloader) | |||
if isinstance(args.batch_sampler, RandomBatchSampler): | |||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
elif isinstance(args.sampler, ReproducibleSampler): | |||
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver): | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
class TorchDriver(Driver): | |||
@@ -183,9 +183,9 @@ class TorchDriver(Driver): | |||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 | |||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`; | |||
# sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif dataloader_args.sampler: | |||
sampler = dataloader_args.sampler | |||
@@ -245,15 +245,14 @@ class TorchDriver(Driver): | |||
# 3. 恢复 sampler 的状态; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif isinstance(dataloader_args.sampler, ReproducibleIterator): | |||
elif isinstance(dataloader_args.sampler, ReproducibleSampler): | |||
sampler = dataloader_args.sampler | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | |||
"`RandomBatchSampler` or `ReproducibleIterator`.") | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||
else: | |||
sampler = RandomBatchSampler( | |||
sampler = ReproducibleBatchSampler( | |||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
batch_size=dataloader_args.batch_size, | |||
drop_last=dataloader_args.drop_last | |||
@@ -263,7 +262,7 @@ class TorchDriver(Driver): | |||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
if not isinstance(sampler, RandomBatchSampler): | |||
if not isinstance(sampler, ReproducibleBatchSampler): | |||
if dataloader_args.drop_last: | |||
batch_idx_in_epoch = len( | |||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
@@ -19,6 +19,10 @@ __all__ = [ | |||
"UnrepeatedSortedSampler", | |||
"UnrepeatedSequentialSampler", | |||
"RandomBatchSampler", | |||
"BucketedBatchSampler", | |||
"ReproducibleBatchSampler", | |||
"re_instantiate_sampler", | |||
"conversion_between_reproducible_and_unrepeated_sampler" | |||
] | |||
@@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre | |||
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler | |||
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler | |||
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler | |||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler | |||
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler | |||
@@ -17,6 +17,9 @@ from abc import abstractmethod | |||
class ReproducibleBatchSampler: | |||
def __init__(self, **kwargs): | |||
pass | |||
@abstractmethod | |||
def set_distributed(self, num_replicas, rank, pad=True): | |||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | |||
@@ -41,6 +44,10 @@ class ReproducibleBatchSampler: | |||
def set_epoch(self, epoch): | |||
pass | |||
@property | |||
def batch_idx_in_epoch(self): | |||
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") | |||
class RandomBatchSampler(ReproducibleBatchSampler): | |||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | |||
@@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler): | |||
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 | |||
:param kwargs: fastNLP 内部使用。 | |||
""" | |||
super().__init__() | |||
self.batch_sampler = batch_sampler | |||
self.batch_size = batch_size | |||
self.drop_last = drop_last | |||