Browse Source

增加RandomBatchSampler

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
7d5ce620f4
25 changed files with 694 additions and 183 deletions
  1. +6
    -2
      fastNLP/core/collators/collator.py
  2. +1
    -2
      fastNLP/core/collators/padders/paddle_padder.py
  3. +1
    -1
      fastNLP/core/controllers/trainer.py
  4. +26
    -27
      fastNLP/core/dataloaders/jittor_dataloader/fdl.py
  5. +33
    -29
      fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  6. +29
    -37
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  7. +16
    -0
      fastNLP/core/dataloaders/utils.py
  8. +0
    -0
      fastNLP/core/dataloaders/utils/__init__.py
  9. +1
    -1
      fastNLP/core/dataset/dataset.py
  10. +2
    -2
      fastNLP/core/drivers/paddle_driver/fleet.py
  11. +3
    -3
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  12. +2
    -2
      fastNLP/core/drivers/paddle_driver/single_device.py
  13. +2
    -2
      fastNLP/core/drivers/torch_driver/single_device.py
  14. +3
    -3
      fastNLP/core/drivers/torch_driver/torch_driver.py
  15. +3
    -2
      fastNLP/core/samplers/__init__.py
  16. +209
    -3
      fastNLP/core/samplers/reproducible_batch_sampler.py
  17. +1
    -2
      fastNLP/core/samplers/reproducible_sampler.py
  18. +2
    -2
      fastNLP/core/utils/__init__.py
  19. +1
    -20
      fastNLP/core/utils/utils.py
  20. +15
    -15
      tests/core/drivers/paddle_driver/test_single_device.py
  21. +3
    -3
      tests/core/drivers/paddle_driver/test_utils.py
  22. +12
    -12
      tests/core/drivers/torch_driver/test_single_device.py
  23. +1
    -1
      tests/core/drivers/torch_driver/test_torch_replace_sampler.py
  24. +3
    -3
      tests/core/drivers/torch_driver/test_utils.py
  25. +319
    -9
      tests/core/samplers/test_reproducible_batch_sampler.py

+ 6
- 2
fastNLP/core/collators/collator.py View File

@@ -65,12 +65,16 @@ def _get_backend() -> str:
return catch_backend[0] return catch_backend[0]


# 方式 (2) # 方式 (2)
for backend in CHECK_BACKEND:
if backend in sys.modules:
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.")
return backend
for key, module in sys.modules.items(): for key, module in sys.modules.items():
catch_backend = _check_module(module) catch_backend = _check_module(module)
if catch_backend: if catch_backend:
break break
if len(catch_backend): if len(catch_backend):
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
logger.debug(f"Find a module file named:{catch_backend[1]} from sys.modules contains backend:{catch_backend[0]}.")
return catch_backend[0] return catch_backend[0]


return 'numpy' return 'numpy'
@@ -227,7 +231,7 @@ class Collator:
设置可以 pad 的 field 默认 pad 为什么类型的 tensor 设置可以 pad 的 field 默认 pad 为什么类型的 tensor


:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None],
若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。
若为 auto ,则在进行 pad 的时候会自动根据调用的环境决定其 backend 。
:return: :return:
""" """
assert backend in SUPPORTED_BACKENDS assert backend in SUPPORTED_BACKENDS


+ 1
- 2
fastNLP/core/collators/padders/paddle_padder.py View File

@@ -74,7 +74,7 @@ def _get_dtype(ele_dtype, dtype, class_name):
elif is_numpy_generic_class(ele_dtype): elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_paddle_dtype_dict.get(ele_dtype) dtype = numpy_to_paddle_dtype_dict.get(ele_dtype)
else: else:
dtype == ele_dtype
dtype = ele_dtype


return dtype return dtype


@@ -174,6 +174,5 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0):
""" """
shapes = get_shape(batch_field) shapes = get_shape(batch_field)
tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype)
# tensor = paddle.full(shape=shapes, dtype=dtype, fill_value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype) tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor return tensor

+ 1
- 1
fastNLP/core/controllers/trainer.py View File

@@ -363,7 +363,6 @@ class Trainer(TrainerEventTrigger):
raise e raise e
finally: finally:
self.on_train_end() self.on_train_end()
self.driver.barrier()


def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl):
def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None:
@@ -441,6 +440,7 @@ class Trainer(TrainerEventTrigger):
""" """
_own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"]) _own_callbacks: List = copy.deepcopy(self._custom_callbacks["all"])
_own_callbacks.extend(self._custom_callbacks[None]) _own_callbacks.extend(self._custom_callbacks[None])
logger.debug(f"Get {len(_own_callbacks)} callback fns through Trainer.on().")
self._custom_callbacks[None] = [] self._custom_callbacks[None] = []
if self.marker is not None: if self.marker is not None:
if len(self._custom_callbacks[self.marker]) == 0: if len(self._custom_callbacks[self.marker]) == 0:


+ 26
- 27
fastNLP/core/dataloaders/jittor_dataloader/fdl.py View File

@@ -14,7 +14,7 @@ else:
from fastNLP.core.dataset import DataSet as Dataset from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import Collator from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet from fastNLP.core.dataset import DataSet as FDataSet




@@ -106,33 +106,33 @@ class JittorDataLoader:
return len(self.dataset) // self.dataset.batch_size return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 return (len(self.dataset) - 1) // self.dataset.batch_size + 1


def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "JittorDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
""" """
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
无意义。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else: else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def set_ignore(self, *field_names) -> "JittorDataLoader":
def set_ignore(self, *field_names) -> Collator:
""" """
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
Ex:: Ex::
@@ -145,18 +145,17 @@ class JittorDataLoader:
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names) self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else: else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]: def get_batch_indices(self) -> List[int]:
""" """
获取当前数据的idx
获取当前 batch 的 idx


:return: :return:
""" """
return self.cur_batch_indices return self.cur_batch_indices



def prepare_jittor_dataloader(): def prepare_jittor_dataloader():
... ...

+ 33
- 29
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -15,8 +15,9 @@ else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader from fastNLP.core.utils.dummy_class import DummyClass as DataLoader


from fastNLP.core.collators.collator import Collator from fastNLP.core.collators.collator import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler




class _PaddleDataset(Dataset): class _PaddleDataset(Dataset):
@@ -54,6 +55,10 @@ class PaddleDataLoader(DataLoader):
if not isinstance(dataset, _PaddleDataset): if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset) dataset = _PaddleDataset(dataset)


if batch_sampler is None:
batch_sampler = RandomBatchSampler(dataset, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last)

super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler, return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
@@ -66,8 +71,6 @@ class PaddleDataLoader(DataLoader):
if isinstance(dataset.dataset, FDataSet): if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.collator self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="paddle") self._collate_fn.set_backend(backend="paddle")
# if collate_fn is not None:
# self._collate_fn.add_collator(collate_fn)
else: else:
self._collate_fn = Collator(backend="paddle") self._collate_fn = Collator(backend="paddle")


@@ -94,33 +97,33 @@ class PaddleDataLoader(DataLoader):
self.cur_batch_indices = indices self.cur_batch_indices = indices
yield data yield data


def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
pad_fn: Callable = None) -> "PaddleDataLoader":
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> Collator:
""" """
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
无意义。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
backend=backend)
return self
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
else: else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def set_ignore(self, *field_names) -> "PaddleDataLoader":
def set_ignore(self, *field_names) -> Collator:
""" """
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
Ex:: Ex::
@@ -133,13 +136,13 @@ class PaddleDataLoader(DataLoader):
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names) self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else: else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]: def get_batch_indices(self) -> List[int]:
""" """
获取当前数据的idx
获取当前 batch 的 idx


:return: :return:
""" """
@@ -147,7 +150,8 @@ class PaddleDataLoader(DataLoader):




def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
return_list: bool = True,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
train_batch_size: int = 1, shuffle: bool = False, train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, drop_last: bool = False, collate_fn: Union[Callable, str, None] = None,
num_workers: int = 0, use_buffer_reader: bool = True, num_workers: int = 0, use_buffer_reader: bool = True,


+ 29
- 37
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -3,14 +3,14 @@ __all__ = [
'prepare_torch_dataloader' 'prepare_torch_dataloader'
] ]


from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator from fastNLP.core.collators import Collator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler


if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler from torch.utils.data import DataLoader, Sampler
@@ -76,6 +76,9 @@ class TorchDataLoader(DataLoader):
if not isinstance(dataset, _FDataSet): if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset) dataset = _FDataSet(dataset)


if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -87,9 +90,6 @@ class TorchDataLoader(DataLoader):
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
self._collate_fn = dataset.dataset.collator self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch") self._collate_fn.set_backend(backend="torch")
# if collate_fn is not None and collate_fn is not default_collate:
# # 防止ddp重新初始化时候将torch dataloader的默认collate加进来
# self._collate_fn.add_collator(collate_fn)
else: else:
self._collate_fn = Collator(backend="torch") self._collate_fn = Collator(backend="torch")
else: else:
@@ -112,31 +112,32 @@ class TorchDataLoader(DataLoader):
yield data yield data


def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> "TorchDataLoader":
pad_fn:Callable=None) -> Collator:
""" """
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。


:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值
无意义。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray,
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self
return self._collate_fn
else: else:
raise ValueError(f"collate_fn is not fastnlp collator")
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def set_ignore(self, *field_names) -> "TorchDataLoader":
def set_ignore(self, *field_names) -> Collator:
""" """
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
Ex:: Ex::
@@ -149,24 +150,15 @@ class TorchDataLoader(DataLoader):
""" """
if isinstance(self._collate_fn, Collator): if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names) self._collate_fn.set_ignore(*field_names)
return self
return self._collate_fn
else: else:
raise ValueError(f"collate_fn is not fastnlp collator")

def get_batch_indices(self) -> List[int]:
"""
获取当前数据的idx

:return:
"""
return self.cur_batch_indices

raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")




def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1, batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
pin_memory: bool = False, drop_last: bool = False, pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None, timeout: float = 0, worker_init_fn: Optional[Callable] = None,


+ 16
- 0
fastNLP/core/dataloaders/utils.py View File

@@ -0,0 +1,16 @@
def indice_collate_wrapper(func):
    """
    Wrap a collate function so that it accepts ``(idx, instance)`` pairs.

    The returned callable separates every pair produced by the dataset into
    its index and its instance, hands the list of instances to ``func`` and
    returns the collected indices together with the collated result.

    :param func: the collate function to decorate; it is called with the list
        of instances only.
    :return: a callable mapping an iterable of ``(idx, instance)`` pairs to
        ``(indices, func(instances))``.
    """

    def wrapper(tuple_data):
        # Materialise once so that one-shot iterables can be traversed twice.
        pairs = list(tuple_data)
        indices = [idx for idx, _ in pairs]
        instances = [ins for _, ins in pairs]
        collated = func(instances)
        return indices, collated

    return wrapper

+ 0
- 0
fastNLP/core/dataloaders/utils/__init__.py View File


+ 1
- 1
fastNLP/core/dataset/dataset.py View File

@@ -780,7 +780,7 @@ class DataSet:
self.collator.set_ignore(*field_names) self.collator.set_ignore(*field_names)


@property @property
def collator(self):
def collator(self) -> Collator:
if self._collator is None: if self._collator is None:
self._collator = Collator() self._collator = Collator()
return self._collator return self._collator

+ 2
- 2
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -22,7 +22,7 @@ from fastNLP.core.utils import (
rank_zero_rm rank_zero_rm
) )
from fastNLP.core.samplers import ( from fastNLP.core.samplers import (
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler, ReproducibleSampler,
ReproducibleBatchSampler, ReproducibleBatchSampler,
RandomSampler, RandomSampler,
@@ -485,7 +485,7 @@ class PaddleFleetDriver(PaddleDriver):


return self.model, model.forward return self.model, model.forward


def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]],
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproduceBatchSampler]],
reproducible: bool = False): reproducible: bool = False):
r""" r"""
根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。


+ 3
- 3
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -22,7 +22,7 @@ from fastNLP.core.log import logger
from fastNLP.core.samplers import ( from fastNLP.core.samplers import (
ReproducibleBatchSampler, ReproducibleBatchSampler,
ReproducibleSampler, ReproducibleSampler,
RandomBatchSampler,
ReproduceBatchSampler,
RandomSampler, RandomSampler,
) )


@@ -345,7 +345,7 @@ class PaddleDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.") "`ReproducibleSampler`.")
else: else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size, batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last drop_last=dataloader_args.drop_last
@@ -476,7 +476,7 @@ class PaddleDriver(Driver):
res.shuffle = True res.shuffle = True
else: else:
res.shuffle = False res.shuffle = False
# RandomBatchSampler 的情况
# ReproduceBatchSampler 的情况
elif hasattr(dataloader.batch_sampler, "batch_sampler"): elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler res.sampler = batch_sampler.sampler


+ 2
- 2
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -14,7 +14,7 @@ from fastNLP.core.utils import (
from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ( from fastNLP.core.samplers import (
ReproducibleBatchSampler, ReproducibleBatchSampler,
RandomBatchSampler,
ReproduceBatchSampler,
ReproducibleSampler, ReproducibleSampler,
RandomSampler, RandomSampler,
re_instantiate_sampler, re_instantiate_sampler,
@@ -177,7 +177,7 @@ class PaddleSingleDriver(PaddleDriver):
logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
else: else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler, batch_sampler=args.batch_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
drop_last=args.drop_last drop_last=args.drop_last


+ 2
- 2
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -15,7 +15,7 @@ from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call from fastNLP.core.utils import auto_param_call
from fastNLP.core.utils.utils import _get_fun_msg from fastNLP.core.utils.utils import _get_fun_msg
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
from fastNLP.core.samplers import RandomSampler from fastNLP.core.samplers import RandomSampler
from fastNLP.core.log import logger from fastNLP.core.log import logger


@@ -113,7 +113,7 @@ class TorchSingleDriver(TorchDriver):
logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.") logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
return replace_sampler(dataloader, sampler) return replace_sampler(dataloader, sampler)
else: else:
batch_sampler = RandomBatchSampler(
batch_sampler = ReproduceBatchSampler(
batch_sampler=args.batch_sampler, batch_sampler=args.batch_sampler,
batch_size=args.batch_size, batch_size=args.batch_size,
drop_last=args.drop_last drop_last=args.drop_last


+ 3
- 3
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -31,7 +31,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler




class TorchDriver(Driver): class TorchDriver(Driver):
@@ -293,7 +293,7 @@ class TorchDriver(Driver):
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or "
"`ReproducibleSampler`.") "`ReproducibleSampler`.")
else: else:
sampler = RandomBatchSampler(
sampler = ReproduceBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size, batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last drop_last=dataloader_args.drop_last
@@ -407,7 +407,7 @@ class TorchDriver(Driver):
res.shuffle = True res.shuffle = True
else: else:
res.shuffle = False res.shuffle = False
# RandomBatchSampler 的情况
# ReproduceBatchSampler 的情况
elif hasattr(dataloader.batch_sampler, "batch_sampler"): elif hasattr(dataloader.batch_sampler, "batch_sampler"):
batch_sampler = dataloader.batch_sampler.batch_sampler batch_sampler = dataloader.batch_sampler.batch_sampler
res.sampler = batch_sampler.sampler res.sampler = batch_sampler.sampler


+ 3
- 2
fastNLP/core/samplers/__init__.py View File

@@ -14,9 +14,10 @@ __all__ = [
"UnrepeatedSortedSampler", "UnrepeatedSortedSampler",
"UnrepeatedSequentialSampler", "UnrepeatedSequentialSampler",


"RandomBatchSampler",
"ReproduceBatchSampler",
"BucketedBatchSampler", "BucketedBatchSampler",
"ReproducibleBatchSampler", "ReproducibleBatchSampler",
"RandomBatchSampler",


"re_instantiate_sampler" "re_instantiate_sampler"
] ]
@@ -26,5 +27,5 @@ from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, Polling
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler from .utils import re_instantiate_sampler
from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
from .reproducible_batch_sampler import ReproduceBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler, RandomBatchSampler



+ 209
- 3
fastNLP/core/samplers/reproducible_batch_sampler.py View File

@@ -1,5 +1,6 @@
__all__ = [ __all__ = [
'BucketedBatchSampler', 'BucketedBatchSampler',
"ReproduceBatchSampler",
"RandomBatchSampler" "RandomBatchSampler"
] ]


@@ -54,13 +55,13 @@ class ReproducibleBatchSampler:
raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.") raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")




class RandomBatchSampler(ReproducibleBatchSampler):
class ReproduceBatchSampler(ReproducibleBatchSampler):
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
""" """
可以使得 batch_sampler 对象状态恢复的 wrapper 。 可以使得 batch_sampler 对象状态恢复的 wrapper 。


:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代
:param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproduceBatchSampler 将首先遍历一边该对象,然后将迭代
出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。
:param batch_size: 每个 batch 的大小是多少。 :param batch_size: 每个 batch 的大小是多少。
:param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。
@@ -143,7 +144,7 @@ class RandomBatchSampler(ReproducibleBatchSampler):
self.need_reinitialize = False self.need_reinitialize = False


def set_distributed(self, num_replicas, rank, pad=True): def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"RandomBatchSampler does not support to change to distributed training.")
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")


def set_epoch(self, epoch): def set_epoch(self, epoch):
if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch): if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, 'set_epoch') and callable(self.batch_sampler.sampler.set_epoch):
@@ -158,6 +159,211 @@ class RandomBatchSampler(ReproducibleBatchSampler):
(len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size (len(self.index_list) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size




class RandomBatchSampler(ReproducibleBatchSampler):
    def __init__(self, dataset, batch_size:int = 32, shuffle: bool = True,
                 drop_last: bool = False, seed: int = 0, **kwargs):
        """
        A batch sampler that splits the dataset indices into (optionally shuffled) batches.

        :param dataset: a data container that implements ``__len__``.
        :param batch_size: the number of samples in each batch.
        :param shuffle: if True, the indices are shuffled (seeded by ``seed + epoch``) before being
            split into batches; if False, the indices are emitted in their natural order.
        :param drop_last: whether to drop the last batch when it holds fewer than ``batch_size`` samples.
        :param seed: the random seed.
        :param kwargs: reserved for fastNLP internal use (state restoring / distributed settings).
        """
        super().__init__()

        self.dataset = dataset

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed

        # Total number of samples consumed so far, including those emitted on the other ranks.
        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)

        # Distributed related settings.
        self.num_replicas = kwargs.get("num_replicas", 1)
        self.rank = kwargs.get("rank", 0)
        self.epoch = kwargs.get("epoch", -1)
        self.pad = kwargs.get("pad", False)  # meaningless when running on a single device

        # True while an iteration is unfinished; set_distributed() and load_state_dict() are
        # not allowed in that state.
        self.during_iter = kwargs.get("during_iter", False)

        # The values below are only used internally to restore a checkpointed state.
        self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
        # __iter__ reads old_num_replicas whenever num_consumed_samples > 0 (which may be supplied
        # through kwargs), so it must exist even if load_state_dict() was never called.
        self.old_num_replicas = kwargs.get('old_num_replicas', self.num_replicas)

    def set_distributed(self, num_replicas, rank, pad=True):
        """
        Switch the sampler to distributed mode.

        :param num_replicas: total number of processes; must be a positive int.
        :param rank: the rank of the current process, ``0 <= rank < num_replicas``.
        :param pad: whether ranks that would receive one sample fewer get padded with a repeated sample.
        :return: self
        """
        assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
                                          "during an unfinished iteration."
        assert num_replicas > 0 and isinstance(num_replicas, int)
        assert isinstance(rank, int) and 0 <= rank < num_replicas
        # When this function is called, every state is expected to be that of a freshly started epoch.
        self.num_replicas = num_replicas
        self.rank = rank
        self.pad = pad

        return self

    def __iter__(self):
        """
        Yield the batches (lists of int indices) that belong to the current rank.
        """
        if self.during_iter:  # the previous iteration never finished, so force a restart
            self.num_consumed_samples = 0
        self.during_iter = True

        indices = list(range(len(self.dataset)))

        if self.shuffle:
            if self.num_consumed_samples > 0:
                # Rebuild the global order used before the checkpoint (old replica count and old
                # batch size), then drop the samples that were already consumed.
                _batches = []
                for _i in range(self.old_num_replicas):
                    _indices = indices[_i:len(indices):self.old_num_replicas]
                    rank_batches = self.batchify(_indices, self.old_batch_size, seed=self.seed + self.epoch)
                    _batches.append(rank_batches)
                batches = list(chain(*[_ for _ in zip(*_batches)]))
                indices = list(chain(*batches))
                indices = indices[self.num_consumed_samples:]
            # Take the share of this rank.
            indices = indices[self.rank:len(indices):self.num_replicas]
            batches = self.batchify(indices, self.batch_size, seed=self.seed + self.epoch)
            batches = list(map(list, batches))
        else:
            indices = indices[self.num_consumed_samples:]
            indices = indices[self.rank:len(indices):self.num_replicas]
            _num_batches = len(indices) // self.batch_size
            if _num_batches == 0:
                batches = [indices]
            else:
                batches = list(map(list, np.array_split(indices[:_num_batches*self.batch_size], _num_batches)))
                if len(indices)%self.batch_size!=0:
                    batches.append(indices[_num_batches*self.batch_size:])

        # The strided split hands one extra sample to every rank below need_pad_num; even this out
        # either by padding the other ranks or by trimming the ranks that got the extra one.
        need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
        if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
            if len(batches) > 0:
                if len(batches[-1])<self.batch_size:
                    batches[-1].append(batches[-1][0])  # the last batch still fits within batch_size
                else:
                    batches.append([batches[-1][0]])
        elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
            if len(batches):
                batches[-1].pop(-1)
                if len(batches[-1])==0:
                    batches.pop(-1)

        assert sum(map(len, batches)) == self.num_left_samples

        if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
            batches = batches[:-1]

        for batch in batches:
            self.num_consumed_samples += self.num_replicas * len(batch)
            yield list(map(int, batch))
        self.during_iter = False
        self.num_consumed_samples = 0
        self.old_batch_size = self.batch_size
        self.old_num_replicas = self.num_replicas
        if self.epoch < 0:  # the user never called set_epoch(); change it so that epochs differ
            self.epoch -= 1

    def batchify(self, indices, batch_size, seed):
        """
        Shuffle ``indices`` and split them into batches.

        Note that ``indices`` is shuffled in place.

        :param indices: List[int]
        :param batch_size: int
        :param seed: int; its absolute value seeds the random generator.
        :return: List[List[int]]
        """
        rng = np.random.default_rng(abs(seed))
        rng.shuffle(indices)
        num_samples = 0
        batches = []
        while num_samples<len(indices):
            batches.append(indices[num_samples:num_samples+batch_size])
            num_samples += batch_size
        return batches

    def set_epoch(self, epoch):
        """
        Set the current epoch, which is added to ``seed`` when shuffling.
        """
        self.epoch = epoch

    @property
    def batch_idx_in_epoch(self):
        """
        The number of batches already produced on this rank in the current epoch.
        """
        if self.drop_last:
            return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size
        else:
            return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \
                   (self.num_left_samples + self.batch_size - 1) // self.batch_size

    @property
    def total_size(self):
        """
        The number of indices this sampler will eventually produce (other ranks included). Because of
        replicas and padding, this value may be equal to, larger than or smaller than len(dataset).

        :return:
        """
        return self.num_consumed_samples + self.num_replicas*self.num_left_samples

    @property
    def num_left_samples(self):
        """
        The number of samples the current rank still has to produce in this iteration.

        :return:
        """
        num_consumed_samples = self.num_consumed_samples
        return math.ceil((len(self.dataset) - num_consumed_samples) / self.num_replicas) if \
            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))

    def __len__(self)->int:
        """
        The number of batches computed from ``total_size`` for one rank.

        :return:
        """
        num_sampler_per_rank = self.total_size//self.num_replicas
        num_batches = num_sampler_per_rank//self.batch_size if self.drop_last else \
            (num_sampler_per_rank+self.batch_size-1)//self.batch_size
        return num_batches

    def state_dict(self) -> Dict:
        """
        Return the state needed to resume this sampler.

        :raises RuntimeError: if the state restored from a previous checkpoint (with a different
            batch size) has not been fully consumed yet.
        """
        if self.old_batch_size != self.batch_size:
            raise RuntimeError(f"{self.__class__.__name__} does not support saving before last checkpoint "
                               "states have been consumed. ")
        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
                  'batch_size': self.batch_size,
                  'num_replicas': self.num_replicas}

        return states

    def load_state_dict(self, states: Dict):
        """
        Restore the sampler from ``states`` as produced by :meth:`state_dict`.

        :param states: the checkpointed state; its sampler type and dataset length must match.
        """
        # If self.during_iter is True, num_consumed_samples must be 0.
        assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
                                          "during an unfinished iteration."

        assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
                                                                  f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
                                            "and current dataset."
        self.seed = states['seed']
        self.epoch = states['epoch']
        self.num_consumed_samples = states['num_consumed_samples']
        if self.num_consumed_samples>=length:  # saved right after the last sample: start from scratch
            self.num_consumed_samples = 0
        if self.shuffle != states['shuffle']:
            logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
                        f"we use shuffle={states['shuffle']}")
        self.shuffle = states["shuffle"]
        self.old_batch_size = states['batch_size']
        self.old_num_replicas = states['num_replicas']


class BucketedBatchSampler(ReproducibleBatchSampler): class BucketedBatchSampler(ReproducibleBatchSampler):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):


+ 1
- 2
fastNLP/core/samplers/reproducible_sampler.py View File

@@ -54,13 +54,12 @@ class RandomSampler(ReproducibleSampler):
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
""" """



:param dataset: 实现了 __len__ 方法的数据容器 :param dataset: 实现了 __len__ 方法的数据容器
:param shuffle: 是否在每次 iterate 的时候打乱顺序。 :param shuffle: 是否在每次 iterate 的时候打乱顺序。
:param seed: 随机数种子。 :param seed: 随机数种子。
:param kwargs: 用户不需要使用,fastNLP 内部使用 :param kwargs: 用户不需要使用,fastNLP 内部使用
""" """
super(RandomSampler, self).__init__()
self.dataset = dataset self.dataset = dataset
self.shuffle = shuffle self.shuffle = shuffle
self.seed = seed self.seed = seed


+ 2
- 2
fastNLP/core/utils/__init__.py View File

@@ -21,7 +21,6 @@ __all__ = [
'nullcontext', 'nullcontext',
'pretty_table_printer', 'pretty_table_printer',
'Option', 'Option',
'indice_collate_wrapper',
'deprecated', 'deprecated',
'seq_len_to_mask', 'seq_len_to_mask',
'rank_zero_rm', 'rank_zero_rm',
@@ -37,6 +36,7 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device
from .torch_utils import torch_move_data_to_device from .torch_utils import torch_move_data_to_device
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
indice_collate_wrapper, deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir
from ..dataloaders.utils import indice_collate_wrapper





+ 1
- 20
fastNLP/core/utils/utils.py View File

@@ -6,7 +6,7 @@ import warnings
from dataclasses import is_dataclass from dataclasses import is_dataclass
from copy import deepcopy from copy import deepcopy
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
from typing import Tuple, Optional from typing import Tuple, Optional
from time import sleep from time import sleep


@@ -35,7 +35,6 @@ __all__ = [
'nullcontext', 'nullcontext',
'pretty_table_printer', 'pretty_table_printer',
'Option', 'Option',
'indice_collate_wrapper',
'deprecated', 'deprecated',
'seq_len_to_mask', 'seq_len_to_mask',
'rank_zero_rm', 'rank_zero_rm',
@@ -513,24 +512,6 @@ class Option(dict):
self.update(state) self.update(state)




def indice_collate_wrapper(func):
    """
    Wrap a collate function so that the sample indices are returned next to the collated batch.

    The returned callable expects an iterable of ``(idx, instance)`` pairs. It gathers the indices
    into one list, the instances into another, and hands only the instances to ``func``.

    :param func: the collate function to wrap; it is called with the list of instances.
    :return: a callable that returns ``(indices, func(instances))``.
    """

    def wrapper(tuple_data):
        # Materialize first so that a one-shot iterator can still be split into two lists.
        pairs = [(idx, ins) for idx, ins in tuple_data]
        indices = [pair[0] for pair in pairs]
        instances = [pair[1] for pair in pairs]
        return indices, func(instances)

    return wrapper


_emitted_deprecation_warnings = set() _emitted_deprecation_warnings = set()






+ 15
- 15
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path from pathlib import Path


from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset
from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -278,7 +278,7 @@ class TestPaddleDriverFunctions:
dataset = PaddleNormalDataset() dataset = PaddleNormalDataset()
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle), BatchSampler(dataset, batch_size=batch_size, shuffle=shuffle),
batch_size, batch_size,
drop_last, drop_last,
@@ -287,7 +287,7 @@ class TestPaddleDriverFunctions:
res = PaddleSingleDriver.get_dataloader_args(dataloader) res = PaddleSingleDriver.get_dataloader_args(dataloader)


assert isinstance(res.dataset, PaddleNormalDataset) assert isinstance(res.dataset, PaddleNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle: if shuffle:
assert isinstance(res.sampler, paddle.io.RandomSampler) assert isinstance(res.sampler, paddle.io.RandomSampler)
else: else:
@@ -387,7 +387,7 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True), 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 paddle.io.RandomSampler(shuffle=True),
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler
""" """
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -400,7 +400,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else: else:
# 此时会替换 batch_sampler # 此时会替换 batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
@@ -414,11 +414,11 @@ class TestSetDistReproDataloader:
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
""" """
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist assert replaced_loader.batch_sampler is dist


self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -450,7 +450,7 @@ class TestSetDistReproDataloader:
""" """
dataloader = DataLoader( dataloader = DataLoader(
dataset=self.dataset, dataset=self.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(self.dataset, batch_size=4, shuffle=shuffle), BatchSampler(self.dataset, batch_size=4, shuffle=shuffle),
batch_size=4, batch_size=4,
drop_last=False, drop_last=False,
@@ -459,7 +459,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
@@ -500,20 +500,20 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches: if idx >= num_consumed_batches:
break break
already_seen_idx.update(batch) already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict() sampler_states = replaced_loader.batch_sampler.state_dict()
else: else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() sampler_states = replaced_loader.batch_sampler.sampler.state_dict()


# 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range
left_idxes = set() left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# 重新改造 dataloader # 重新改造 dataloader
new_loader = DataLoader( new_loader = DataLoader(
dataset=replaced_loader.dataset, dataset=replaced_loader.dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size), BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size),
batch_size=batch_size, batch_size=batch_size,
drop_last=False, drop_last=False,
@@ -603,7 +603,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
dataset = PaddleRandomMaxDataset(40, 10) dataset = PaddleRandomMaxDataset(40, 10)
dataloader = DataLoader( dataloader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=4), 4, False)
) )
driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu") driver1, driver2 = generate_random_driver(10, 10, fp16, "gpu"), generate_random_driver(10, 10, False, "gpu")


@@ -627,7 +627,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 更改 batch_size # 更改 batch_size
dataloader = DataLoader( dataloader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=RandomBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False)
) )
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True)
replaced_loader = load_states.pop("dataloader") replaced_loader = load_states.pop("dataloader")
@@ -637,7 +637,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. 检查 batch_sampler 是否被正确地加载和替换 # 2. 检查 batch_sampler 是否被正确地加载和替换
assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4




+ 3
- 3
tests/core/drivers/paddle_driver/test_utils.py View File

@@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.utils import (
replace_batch_sampler, replace_batch_sampler,
replace_sampler, replace_sampler,
) )
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE: if _NEED_IMPORT_PADDLE:
import paddle import paddle
@@ -36,12 +36,12 @@ def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices,
def test_replace_batch_sampler(): def test_replace_batch_sampler():
dataset = PaddleNormalDataset(10) dataset = PaddleNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32) dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)


replaced_loader = replace_batch_sampler(dataloader, batch_sampler) replaced_loader = replace_batch_sampler(dataloader, batch_sampler)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, PaddleNormalDataset) assert isinstance(replaced_loader.dataset, PaddleNormalDataset)
assert len(replaced_loader.dataset) == len(dataset) assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16 assert replaced_loader.batch_sampler.batch_size == 16


+ 12
- 12
tests/core/drivers/torch_driver/test_single_device.py View File

@@ -2,7 +2,7 @@ import pytest
from pathlib import Path from pathlib import Path


from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset
from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset
@@ -17,7 +17,7 @@ if _NEED_IMPORT_PADDLE:


def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
""" """
建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader
建立一个 batch_sampler 为 ReproduceBatchSampler 的 dataloader
""" """
if shuffle: if shuffle:
sampler = torch.utils.data.RandomSampler(dataset) sampler = torch.utils.data.RandomSampler(dataset)
@@ -25,7 +25,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last):
sampler = torch.utils.data.SequentialSampler(dataset) sampler = torch.utils.data.SequentialSampler(dataset)
dataloader = DataLoader( dataloader = DataLoader(
dataset=dataset, dataset=dataset,
batch_sampler=RandomBatchSampler(
batch_sampler=ReproduceBatchSampler(
BatchSampler( BatchSampler(
sampler, batch_size=batch_size, drop_last=drop_last sampler, batch_size=batch_size, drop_last=drop_last
), ),
@@ -306,7 +306,7 @@ class TestTorchDriverFunctions:
res = TorchSingleDriver.get_dataloader_args(dataloader) res = TorchSingleDriver.get_dataloader_args(dataloader)


assert isinstance(res.dataset, TorchNormalDataset) assert isinstance(res.dataset, TorchNormalDataset)
assert isinstance(res.batch_sampler, RandomBatchSampler)
assert isinstance(res.batch_sampler, ReproduceBatchSampler)
if shuffle: if shuffle:
assert isinstance(res.sampler, torch.utils.data.RandomSampler) assert isinstance(res.sampler, torch.utils.data.RandomSampler)
else: else:
@@ -401,7 +401,7 @@ class TestSetDistReproDataloader:
""" """
测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现
当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True),
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler
只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 ReproduceBatchSampler
""" """
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True)
@@ -414,7 +414,7 @@ class TestSetDistReproDataloader:
assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler)
else: else:
# 此时会替换 batch_sampler # 此时会替换 batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
@@ -428,11 +428,11 @@ class TestSetDistReproDataloader:
应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler
""" """
dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle)
dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
dist = ReproduceBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False)
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler is dist assert replaced_loader.batch_sampler is dist


self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle)
@@ -466,7 +466,7 @@ class TestSetDistReproDataloader:
replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) assert not (replaced_loader.batch_sampler is dataloader.batch_sampler)
assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size
assert replaced_loader.drop_last == dataloader.drop_last assert replaced_loader.drop_last == dataloader.drop_last
@@ -502,14 +502,14 @@ class TestSetDistReproDataloader:
if idx >= num_consumed_batches: if idx >= num_consumed_batches:
break break
already_seen_idx.update(batch) already_seen_idx.update(batch)
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
sampler_states = replaced_loader.batch_sampler.state_dict() sampler_states = replaced_loader.batch_sampler.state_dict()
else: else:
sampler_states = replaced_loader.batch_sampler.sampler.state_dict() sampler_states = replaced_loader.batch_sampler.sampler.state_dict()


# 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range # 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range
left_idxes = set() left_idxes = set()
if isinstance(replaced_loader.batch_sampler, RandomBatchSampler):
if isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler):
batch_size = replaced_loader.batch_sampler.batch_size batch_size = replaced_loader.batch_sampler.batch_size
sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size
# 重新改造 dataloader # 重新改造 dataloader
@@ -613,7 +613,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
# 2. 检查 batch_sampler 是否被正确地加载和替换 # 2. 检查 batch_sampler 是否被正确地加载和替换
assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert replaced_loader.batch_sampler is dataloader.batch_sampler assert replaced_loader.batch_sampler is dataloader.batch_sampler
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"]
assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4




+ 1
- 1
tests/core/drivers/torch_driver/test_torch_replace_sampler.py View File

@@ -30,7 +30,7 @@ class SequenceDataSet:




def check_replace_sampler(driver): def check_replace_sampler(driver):
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproduceBatchSampler
# reproducible 是 True 和 False # reproducible 是 True 和 False


# 需要 check 返回的 sampler 和 dataloader 都不同了 # 需要 check 返回的 sampler 和 dataloader 都不同了


+ 3
- 3
tests/core/drivers/torch_driver/test_utils.py View File

@@ -4,7 +4,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
replace_batch_sampler, replace_batch_sampler,
replace_sampler, replace_sampler,
) )
from fastNLP.core.samplers import RandomBatchSampler, RandomSampler
from fastNLP.core.samplers import ReproduceBatchSampler, RandomSampler
from torch.utils.data import DataLoader, BatchSampler from torch.utils.data import DataLoader, BatchSampler


from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.datasets.torch_data import TorchNormalDataset
@@ -14,12 +14,12 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
def test_replace_batch_sampler(): def test_replace_batch_sampler():
dataset = TorchNormalDataset(10) dataset = TorchNormalDataset(10)
dataloader = DataLoader(dataset, batch_size=32) dataloader = DataLoader(dataset, batch_size=32)
batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)
batch_sampler = ReproduceBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False)


replaced_loader = replace_batch_sampler(dataloader, batch_sampler) replaced_loader = replace_batch_sampler(dataloader, batch_sampler)


assert not (replaced_loader is dataloader) assert not (replaced_loader is dataloader)
assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler)
assert isinstance(replaced_loader.batch_sampler, ReproduceBatchSampler)
assert isinstance(replaced_loader.dataset, TorchNormalDataset) assert isinstance(replaced_loader.dataset, TorchNormalDataset)
assert len(replaced_loader.dataset) == len(dataset) assert len(replaced_loader.dataset) == len(dataset)
assert replaced_loader.batch_sampler.batch_size == 16 assert replaced_loader.batch_sampler.batch_size == 16


+ 319
- 9
tests/core/samplers/test_reproducible_batch_sampler.py View File

@@ -5,7 +5,7 @@ import pytest
from itertools import chain from itertools import chain
from copy import deepcopy from copy import deepcopy


from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.datasets.torch_data import TorchNormalDataset


@@ -19,7 +19,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# before_batch_size = 7 # before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100) # dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size) # dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) # dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# #
# forward_steps = 3 # forward_steps = 3
@@ -29,15 +29,15 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# #
# # 1. 保存状态 # # 1. 保存状态
# _get_re_batchsampler = dataloader.batch_sampler # _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict() # state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size, # assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "RandomBatchSampler"}
# "sampler_type": "ReproduceBatchSampler"}
# #
# # 2. 断点重训,重新生成一个 dataloader; # # 2. 断点重训,重新生成一个 dataloader;
# # 不改变 batch_size; # # 不改变 batch_size;
# dataloader = DataLoader(dataset, batch_size=before_batch_size) # dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state) # re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) # dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# #
@@ -54,7 +54,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# # 改变 batch_size; # # 改变 batch_size;
# after_batch_size = 3 # after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size) # dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state) # re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) # dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# #
@@ -100,7 +100,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# dataset = TorchNormalDataset(num_of_data=100) # dataset = TorchNormalDataset(num_of_data=100)
# # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; # # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) # dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) # dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# #
# # 将一轮的所有数据保存下来,看是否恢复的是正确的; # # 将一轮的所有数据保存下来,看是否恢复的是正确的;
@@ -112,13 +112,13 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset
# #
# # 1. 保存状态 # # 1. 保存状态
# _get_re_batchsampler = dataloader.batch_sampler # _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, RandomBatchSampler)
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict() # state = _get_re_batchsampler.state_dict()
# #
# # 2. 断点重训,重新生成一个 dataloader; # # 2. 断点重训,重新生成一个 dataloader;
# # 不改变 batch_size; # # 不改变 batch_size;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) # dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state) # re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler) # dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# #
@@ -511,3 +511,313 @@ class TestBucketedBatchSampler:
already_seen_set.update(batch) already_seen_set.update(batch)


assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset) assert len(already_seen_set)==len(dataset) if drop_last is False else len(already_seen_set)<=len(dataset)


class TestRandomBatchSampler:
    """Tests for ``RandomBatchSampler``: batch counts, state save/load and distributed splitting."""

    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
    def test_single_num_batch(self, shuffle, drop_last, num):
        """Check the number of batches produced for dataset sizes around the batch size."""
        # A dataset smaller than one batch must not raise.
        dataset = DatasetWithVaryLength(num_of_data=num)
        before_batch_size = 7
        re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                             drop_last=drop_last,
                                             shuffle=shuffle)
        count = len(list(re_batchsampler))
        if drop_last:
            assert count == num // before_batch_size, num
        else:
            assert count == (num + before_batch_size - 1) // before_batch_size, num

    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    def test_single(self, shuffle, drop_last):
        """Save the state mid-epoch and resume with the same and with different batch sizes."""
        before_batch_size = 7

        dataset = DatasetWithVaryLength(num_of_data=1000)
        re_batchsampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                             drop_last=drop_last,
                                             shuffle=shuffle)
        re_batchsampler.set_epoch(0)
        forward_steps = 10
        iterator = iter(re_batchsampler)
        already_generate_indices = set()
        for _ in range(forward_steps):
            already_generate_indices.update(next(iterator))

        # 1. Save the state.
        state = re_batchsampler.state_dict()

        # 2. Resume from the checkpoint and keep training.
        re_batchsampler2 = RandomBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
                                              drop_last=drop_last,
                                              shuffle=shuffle)
        re_batchsampler2.load_state_dict(state)
        re_batchsampler2.set_epoch(0)
        new_already_generate_indices = set()
        for batch in re_batchsampler2:
            for b in batch:
                assert b not in already_generate_indices
            new_already_generate_indices.update(batch)
        if drop_last is False:
            assert len(new_already_generate_indices.union(already_generate_indices)) == len(dataset)

        # Resume again, this time with a different batch_size.
        after_batch_size = 3
        re_batchsampler3 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
                                              drop_last=drop_last,
                                              shuffle=shuffle)
        re_batchsampler3.load_state_dict(state)
        re_batchsampler3.set_epoch(0)
        count = 0
        for batch in re_batchsampler3:
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
            count += 1
            if count > 5:
                break

        # Saving again is not allowed while the resumed epoch is still unfinished.
        after_batch_size = 5
        with pytest.raises(RuntimeError):
            re_batchsampler3.state_dict()

        for batch in re_batchsampler3:  # consume everything so that saving becomes possible
            pass

        already_generate_indices = set()
        count = 0
        for batch in re_batchsampler3:  # start over
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)
            count += 1
            if count > 5:
                break

        state = re_batchsampler3.state_dict()
        # drop_last is False here, so in the end every sample has to be seen.
        re_batchsampler4 = RandomBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
                                              drop_last=False,
                                              shuffle=shuffle)
        re_batchsampler4.load_state_dict(state)
        re_batchsampler4.set_epoch(0)

        for batch in re_batchsampler4:
            for b in batch:
                assert b not in already_generate_indices
            already_generate_indices.update(batch)

        assert len(already_generate_indices) == len(dataset)

    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    def test_multi(self, shuffle, drop_last, pad):
        """Check distributed splitting, then resuming on a single process and on three replicas."""
        num_replica = 2
        dataset = DatasetWithVaryLength(num_of_data=1000)
        batch_size = 5
        lengths = []
        rank0_already_seen_indexes = None
        for rank in range(num_replica):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            already_seen_indexes = set()
            repeat_count = 0
            for batch in sampler:
                for b in batch:
                    repeat_count += int(b in already_seen_indexes)
                    if rank0_already_seen_indexes:  # ranks must not overlap
                        assert b not in rank0_already_seen_indexes
                already_seen_indexes.update(batch)
            if rank0_already_seen_indexes is None:
                rank0_already_seen_indexes = already_seen_indexes
            if pad:  # a single repetition is tolerated when padding
                assert repeat_count <= 1
            else:
                assert repeat_count == 0

        assert len(set(lengths)) == 1, lengths  # every process yields the same number of batches

        # Saving in the multi-process setting.
        already_seen_indexes = set()
        for rank in range(num_replica):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            count = 0
            for batch in sampler:
                already_seen_indexes.update(batch)
                if count > 5:
                    break
                count += 1
        # The state is taken from the sampler of the last rank.
        state = sampler.state_dict()

        # Switch to a single process.
        new_batch_size = 6
        new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
        new_sampler.load_state_dict(state)
        repeat_count = 0
        new_already_seen_indexes = set(already_seen_indexes)

        for batch in new_sampler:
            for b in batch:
                repeat_count += int(b in new_already_seen_indexes)
            new_already_seen_indexes.update(batch)
        if pad:  # a single repetition is tolerated when padding
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
        if drop_last is False:  # nothing dropped, so everything must have been seen
            assert len(new_already_seen_indexes) == len(dataset)

        # Resume with a different number of devices.
        num_replica = 3
        new_sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
        new_sampler.set_epoch(0)
        new_sampler.load_state_dict(state)
        new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
        repeat_count = 0

        for batch in new_sampler:
            for b in batch:
                repeat_count += int(b in already_seen_indexes)
        if pad:  # a single repetition is tolerated when padding
            assert repeat_count <= 1
        else:
            assert repeat_count == 0

    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
    @pytest.mark.parametrize('num_replicas', [2, 3])
    def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replicas):
        """Check that every replica iterates over the same number of batches."""
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        batch_size = 6
        if num_replicas * batch_size > num_samples:
            return
        lengths = []
        for i in range(num_replicas):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_distributed(num_replicas, rank=i, pad=pad)
            sampler.set_epoch(0)
            lengths.append(len(list(sampler)))
        assert len(set(lengths)) == 1

    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
    @pytest.mark.parametrize('num_replicas', [1, 2, 3])
    def test_multi_save_load(self, shuffle, drop_last, pad, num_samples, num_replicas):
        """
        Check that data which has already been consumed (forwarded) is handled correctly
        when the sampler state is saved and loaded.

        :return:
        """
        batch_size = 6
        dataset = DatasetWithVaryLength(num_of_data=num_samples)
        samplers = []
        num_consumed_samples_array = list(range(0, num_samples + num_replicas, num_replicas))
        for i in range(num_replicas):
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
                                         shuffle=shuffle, drop_last=drop_last)

            sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
            samplers.append(sampler)
        count = 0
        # already_seen_sets[k] holds the indices seen by all replicas after k joint steps.
        already_seen_sets = [set()]
        already_seen_set = set()
        for batchs in zip(*samplers):
            already_seen_set.update(chain(*batchs))
            already_seen_sets.append(deepcopy(already_seen_set))
            count += 1
            if count > 3:
                break
        states = samplers[0].state_dict()
        for i, seen_so_far in enumerate(already_seen_sets):
            states['num_consumed_samples'] = num_consumed_samples_array[i]
            # NOTE(review): `states` is updated but never loaded into this fresh sampler, so each
            # iteration only checks that a full pass covers whatever was not seen yet - confirm intent.
            sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
                                         shuffle=shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            already_seen_set = deepcopy(seen_so_far)
            for batch in sampler:
                already_seen_set.update(batch)
            if drop_last is False:
                assert len(already_seen_set) == len(dataset)
            else:
                assert len(already_seen_set) <= len(dataset)

        # Save once more after having resumed.
        sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size + 1,
                                     shuffle=shuffle,
                                     drop_last=drop_last)
        sampler.set_epoch(0)
        states['num_consumed_samples'] = num_consumed_samples_array[2]
        if len(already_seen_sets) < 3:
            return
        already_seen_set = already_seen_sets[2]
        count = 0
        for batch in sampler:
            already_seen_set.update(batch)
            count += 1
            if count > 6:
                break

        states = sampler.state_dict()
        num_consumed_samples_array = list(range(len(dataset)))
        states['num_consumed_samples'] = num_consumed_samples_array[count]
        sampler = RandomBatchSampler(dataset, length=dataset.data, batch_size=batch_size // 2,
                                     shuffle=shuffle,
                                     drop_last=drop_last)
        sampler.load_state_dict(states)
        sampler.set_epoch(0)
        for batch in sampler:
            already_seen_set.update(batch)

        if drop_last is False:
            assert len(already_seen_set) == len(dataset)
        else:
            assert len(already_seen_set) <= len(dataset)

Loading…
Cancel
Save