@@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
 from fastNLP.envs import rank_zero_call
-from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME
@@ -49,13 +49,13 @@ class Driver(ABC):
             duplicated across different gpus; when it is 'unrepeatdist', the dataloader should guarantee that the data
             iterated on all gpus, merged together, is exactly the original data, while the number of batches is allowed
             to differ across gpus. When the trainer kwarg `use_dist_sampler` is True, this value is "dist",
             otherwise None; when the evaluator kwarg `use_dist_sampler` is True, this value is "unrepeatdist", otherwise None;
-            note that when dist is a ReproducibleIterator or RandomBatchSampler, this function is being called by driver.load during checkpoint resumption;
+            note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, this function is being called by driver.load during checkpoint resumption;
             when dist is a str or None, this function is being called by the trainer during initialization;
         :param reproducible: if False, nothing special needs to be done; if True, the returned dataloader must be able to
             save its current iteration state so that it can be reloaded later.
         :return: a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here); in addition,
-            if the incoming dataloader holds a ReproducibleSampler or RandomBatchSampler, a re-initialized instance should be placed into the returned
+            if the incoming dataloader holds a ReproducibleSampler or ReproducibleBatchSampler, a re-initialized instance should be placed into the returned
             dataloader. If dist is None and reproducible is False, the original object may be returned directly.
         """
         if dist is None and reproducible is False:
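The docstring above describes a three-way dispatch on `dist`. As a minimal sketch (not fastNLP's actual trainer code; the helper name and flags are assumptions), the string values are derived from the `use_dist_sampler` kwargs roughly like this:

    # Hypothetical helper, only to illustrate the mapping described in the docstring.
    def resolve_dist_flag(is_trainer: bool, use_dist_sampler: bool):
        if not use_dist_sampler:
            return None
        # trainer: shard data across gpus without duplication ("dist");
        # evaluator: the union of data over gpus equals the original dataset ("unrepeatdist").
        return "dist" if is_trainer else "unrepeatdist"

    dist_flag = resolve_dist_flag(is_trainer=True, use_dist_sampler=True)   # -> "dist"
    # dataloader = driver.set_dist_repro_dataloader(dataloader, dist=dist_flag, reproducible=False)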
@@ -3,7 +3,7 @@ from typing import Dict, Union
 from .jittor_driver import JittorDriver
 from fastNLP.core.utils import auto_param_call
 from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 
 if _NEED_IMPORT_JITTOR:
     import jittor
@@ -99,10 +99,10 @@ class JittorSingleDriver(JittorDriver):
     def is_distributed(self):
         return False
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # the reproducible-related functionality has not been implemented yet
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             raise NotImplementedError
             dataloader.batch_sampler = dist_sample
         if isinstance(dist, ReproducibleSampler):
@@ -10,7 +10,7 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_PADDLE:
@@ -139,12 +139,12 @@ class PaddleSingleDriver(PaddleDriver):
         """
         return paddle_move_data_to_device(batch, "gpu:0")
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler],
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                   reproducible: bool = False, sampler_or_batch_sampler=None):
         # `IteratorDataset` is not supported for now
         assert dataloader.dataset_kind != _DatasetKind.ITER, \
             "FastNLP does not support `IteratorDataset` now."
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             dataloader.batch_sampler = dist
             return dataloader
         if isinstance(dist, ReproducibleSampler):
@@ -154,11 +154,11 @@ class PaddleSingleDriver(PaddleDriver):
         if reproducible:
             if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
                 return dataloader
-            elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
+            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
                 return dataloader
             else:
                 # TODO
-                batch_sampler = RandomBatchSampler(
+                batch_sampler = ReproducibleBatchSampler(
                     batch_sampler=dataloader.batch_sampler,
                     batch_size=dataloader.batch_sampler.batch_size,
                     drop_last=dataloader.drop_last
@@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
 )
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, RandomBatchSampler, \
+from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
     re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
@@ -446,11 +446,11 @@ class TorchDDPDriver(TorchDriver):
             # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
             return self._test_step(batch)
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]]=None,
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
                                   reproducible: bool = False):
-        # if dist is a RandomBatchSampler or ReproducibleIterator, this is being called by driver.load during checkpoint resumption;
+        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, this is being called by driver.load during checkpoint resumption;
         # note that there is no need to call dist_sampler.set_distributed here; if the user is using TorchDDPDriver, that function has already been called when the Trainer was initialized;
-        if isinstance(dist, RandomBatchSampler):
+        if isinstance(dist, ReproducibleBatchSampler):
             dist.set_distributed(
                 num_replicas=self.world_size,
                 rank=self.global_rank,
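For reference, `set_distributed` is how a reproducible (batch) sampler learns about the process group; mirroring the call above, a standalone usage sketch (the `batch_sampler` variable and the literal values are illustrative) looks like:

    # Shard a reproducible batch sampler across 4 processes; pad=True keeps the
    # number of batches identical on every rank by padding the shorter shards.
    batch_sampler.set_distributed(
        num_replicas=4,   # world size
        rank=1,           # index of the current process
        pad=True,
    )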
@@ -472,7 +472,7 @@ class TorchDDPDriver(TorchDriver):
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
         else:
-            if isinstance(dist, RandomBatchSampler):
+            if isinstance(dist, ReproducibleBatchSampler):
                 dist = re_instantiate_sampler(dist)
                 return replace_batch_sampler(dataloader, dist)
             if isinstance(dist, ReproducibleSampler):
@@ -483,7 +483,7 @@ class TorchDDPDriver(TorchDriver):
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # if the user's trainer.use_dist_sampler is True, whether or not checkpoint resumption is in progress does not affect the behavior here;
-            if isinstance(args.batch_sampler, RandomBatchSampler):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
                 batch_sampler = re_instantiate_sampler(args.batch_sampler)
                 batch_sampler.set_distributed(
                     num_replicas=self.world_size,
@@ -13,7 +13,7 @@ __all__ = [
 from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
 from fastNLP.core.log import logger
@@ -129,18 +129,18 @@ class TorchSingleDriver(TorchDriver):
         else:
             return self._test_step(batch)
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None,
+    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
                                   reproducible: bool = False):
-        # if dist is a RandomBatchSampler or ReproducibleSampler, this is being called by driver.load during checkpoint resumption;
-        if isinstance(dist, RandomBatchSampler):
+        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, this is being called by driver.load during checkpoint resumption;
+        if isinstance(dist, ReproducibleBatchSampler):
             return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleSampler):
             return replace_sampler(dataloader, dist)
 
         # if dist is a str or None, this is being called during trainer initialization;
         args = self.get_dataloader_args(dataloader)
-        if isinstance(args.batch_sampler, RandomBatchSampler):
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
             batch_sampler = re_instantiate_sampler(args.batch_sampler)
             return replace_batch_sampler(dataloader, batch_sampler)
         elif isinstance(args.sampler, ReproducibleSampler):
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver):
             return replace_sampler(dataloader, sampler)
 
         if reproducible:
-            batch_sampler = RandomBatchSampler(
+            batch_sampler = ReproducibleBatchSampler(
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,
                 drop_last=args.drop_last
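The single-device branch above amounts to: wrap the dataloader's existing batch_sampler so that its position in the epoch can be saved and restored, then rebuild the dataloader around the wrapper. A rough standalone sketch of that idea with a plain PyTorch DataLoader (RandomBatchSampler is fastNLP's concrete wrapper shown later in this diff; constructing a new DataLoader directly stands in for replace_batch_sampler here):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP.core.samplers import RandomBatchSampler

    dataset = TensorDataset(torch.arange(100).float())
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Wrap the existing batch_sampler so the number of consumed batches can be
    # checkpointed and the remaining batches replayed after a resume.
    wrapped = RandomBatchSampler(
        batch_sampler=dataloader.batch_sampler,
        batch_size=8,
        drop_last=False,
    )
    reproducible_loader = DataLoader(dataset, batch_sampler=wrapped)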
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
-from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
 
 
 class TorchDriver(Driver):
@@ -183,9 +183,9 @@ class TorchDriver(Driver):
         # 1. the sampler's state, because we support resume training, i.e. resuming exactly at a specific batch;
         # first, a pytorch DataLoader always has a sampler; on the other hand, during checkpoint resumption we always replace the dataloader's
-        # sampler with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, the batch_sampler is replaced with a `RandomBatchSampler`;
+        # sampler with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
         dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
         elif dataloader_args.sampler:
             sampler = dataloader_args.sampler
@@ -245,15 +245,14 @@ class TorchDriver(Driver):
         # 3. restore the sampler's state;
         dataloader_args = self.get_dataloader_args(dataloader)
-        if isinstance(dataloader_args.batch_sampler, RandomBatchSampler):
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
             sampler = dataloader_args.batch_sampler
-        elif isinstance(dataloader_args.sampler, ReproducibleIterator):
+        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
             sampler = dataloader_args.sampler
         elif self.is_distributed():
-            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
-                               "`RandomBatchSampler` or `ReproducibleIterator`.")
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our `ReproducibleBatchSampler` or `ReproducibleSampler`.")
         else:
-            sampler = RandomBatchSampler(
+            sampler = ReproducibleBatchSampler(
                 batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
@@ -263,7 +262,7 @@ class TorchDriver(Driver):
         # 4. update trainer_state.batch_idx_in_epoch
         # here `sampler` is a sampler similar to RandomSampler, not a batch_sampler;
-        if not isinstance(sampler, RandomBatchSampler):
+        if not isinstance(sampler, ReproducibleBatchSampler):
             if dataloader_args.drop_last:
                 batch_idx_in_epoch = len(
                     sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size
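As a quick sanity check of the drop_last branch above: with len(sampler) == 100, batch_size == 8 and sampler.num_left_samples == 36, the resumed position is 100 // 8 - 36 // 8 = 12 - 4 = 8, i.e. eight full batches were already consumed in this epoch.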
@@ -19,6 +19,10 @@ __all__ = [
     "UnrepeatedSortedSampler",
     "UnrepeatedSequentialSampler",
 
+    "RandomBatchSampler",
+    "BucketedBatchSampler",
+    "ReproducibleBatchSampler",
+
     "re_instantiate_sampler",
     "conversion_between_reproducible_and_unrepeated_sampler"
 ]
@@ -28,5 +32,5 @@ from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, Unre
 from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
 from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
 from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
-from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler
+from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
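With the re-export above, both the abstract interface and the concrete batch samplers are available from one public path. A minimal sketch of the import this change enables:

    from fastNLP.core.samplers import (
        ReproducibleBatchSampler,   # abstract interface for checkpointable batch samplers
        RandomBatchSampler,         # concrete wrapper around an existing batch_sampler
        BucketedBatchSampler,       # length-bucketed batch sampler
    )

    # RandomBatchSampler is declared below as a subclass of the new abstract base.
    assert issubclass(RandomBatchSampler, ReproducibleBatchSampler)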
@@ -17,6 +17,9 @@ from abc import abstractmethod
 class ReproducibleBatchSampler:
+    def __init__(self, **kwargs):
+        pass
+
     @abstractmethod
     def set_distributed(self, num_replicas, rank, pad=True):
         raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
@@ -41,6 +44,10 @@ class ReproducibleBatchSampler:
     def set_epoch(self, epoch):
         pass
 
+    @property
+    def batch_idx_in_epoch(self):
+        raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")
+
 
 class RandomBatchSampler(ReproducibleBatchSampler):
     # the values of these two parameters should be obtained via the driver's get_dataloader_args function;
@@ -54,6 +61,8 @@ class RandomBatchSampler(ReproducibleBatchSampler):
         :param drop_last: whether to drop the last batch when it cannot be filled up to batch_size samples.
         :param kwargs: used internally by fastNLP.
         """
+        super().__init__()
         self.batch_sampler = batch_sampler
         self.batch_size = batch_size
         self.drop_last = drop_last
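Taken together, the additions in this file define the contract a concrete batch sampler must satisfy. A toy subclass (illustrative only, not part of fastNLP; it implements just the pieces visible in this diff and omits anything else the real base class may require, e.g. state_dict handling):

    from fastNLP.core.samplers import ReproducibleBatchSampler

    class FixedOrderBatchSampler(ReproducibleBatchSampler):
        """Yields index batches in a fixed order and tracks its position in the epoch."""

        def __init__(self, num_samples, batch_size):
            super().__init__()
            self.num_samples = num_samples
            self.batch_size = batch_size
            self._consumed_batches = 0

        def set_distributed(self, num_replicas, rank, pad=True):
            # A real implementation would shard batches across ranks here.
            self.num_replicas, self.rank, self.pad = num_replicas, rank, pad

        def set_epoch(self, epoch):
            self.epoch = epoch

        @property
        def batch_idx_in_epoch(self):
            return self._consumed_batches

        def __iter__(self):
            indices = list(range(self.num_samples))
            for start in range(0, self.num_samples, self.batch_size):
                self._consumed_batches += 1
                yield indices[start:start + self.batch_size]

        def __len__(self):
            return (self.num_samples + self.batch_size - 1) // self.batch_size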