
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

Tag: tags/v1.0.0alpha
Author: YWMditto, 3 years ago
Commit: 6c333b0f6e
27 changed files with 1114 additions and 628 deletions
  1. +2    -0     fastNLP/core/controllers/evaluator.py
  2. +1    -2     fastNLP/core/controllers/trainer.py
  3. +34   -22    fastNLP/core/dataloaders/torch_dataloader/fdl.py
  4. +2    -2     fastNLP/core/drivers/driver.py
  5. +2    -2     fastNLP/core/drivers/jittor_driver/mpi.py
  6. +6    -6     fastNLP/core/drivers/jittor_driver/single_device.py
  7. +5    -5     fastNLP/core/drivers/paddle_driver/fleet.py
  8. +4    -4     fastNLP/core/drivers/paddle_driver/single_device.py
  9. +26   -15    fastNLP/core/drivers/torch_driver/ddp.py
  10. +145 -227   fastNLP/core/drivers/torch_driver/dist_utils.py
  11. +4   -5     fastNLP/core/drivers/torch_driver/single_device.py
  12. +6   -7     fastNLP/core/drivers/torch_driver/torch_driver.py
  13. +17  -7     fastNLP/core/samplers/__init__.py
  14. +14  -5     fastNLP/core/samplers/reproducible_batch_sampler.py
  15. +133 -13    fastNLP/core/samplers/reproducible_sampler.py
  16. +42  -13    fastNLP/core/samplers/unrepeated_sampler.py
  17. +42  -0     fastNLP/core/samplers/utils.py
  18. +1   -3     fastNLP/core/utils/rich_progress.py
  19. +2   -2     tests/core/dataloaders/paddle_dataloader/test_fdl.py
  20. +2   -2     tests/core/dataloaders/torch_dataloader/test_fdl.py
  21. +126 -127   tests/core/dataset/test_dataset.py
  22. +2   -2     tests/core/drivers/paddle_driver/test_single_device.py
  23. +6   -31    tests/core/drivers/torch_driver/test_dist_utils.py
  24. +1   -1     tests/core/drivers/torch_driver/test_torch_replace_sampler.py
  25. +9   -9     tests/core/samplers/test_reproducible_batch_sampler.py
  26. +431 -107   tests/core/samplers/test_reproducible_sampler.py
  27. +49  -9     tests/core/samplers/test_unrepeated_sampler.py

+2 -0   fastNLP/core/controllers/evaluator.py

@@ -219,6 +219,7 @@ class Evaluator:
    def remove_progress_bar(self, dataloader_name):
        if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
            f_rich_progress.destroy_task(self._rich_task_id)
            f_rich_progress.refresh()  # so that the finished bar actually disappears
            delattr(self, '_rich_task_id')
        elif self.progress_bar == 'raw':
            desc = 'Evaluation ends'
@@ -229,6 +230,7 @@ class Evaluator:
    def finally_progress_bar(self):
        if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
            f_rich_progress.destroy_task(self._rich_task_id)
            f_rich_progress.refresh()
            delattr(self, '_rich_task_id')


    @property
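Both insertions follow the same pattern: once a rich task is destroyed, an explicit refresh() repaints the display so the finished bar does not linger. A hedged sketch of that pattern, assuming `f_rich_progress` exposes the usual rich.progress.Progress-style interface (add_task/update) alongside the destroy_task/refresh calls used in this diff:

```python
from fastNLP.core.utils.rich_progress import f_rich_progress  # import path assumed

def evaluate_with_bar(n_batches: int):
    # create a task, advance it, then tear it down and force a final repaint
    task_id = f_rich_progress.add_task(description='Eval', total=n_batches)
    for _ in range(n_batches):
        f_rich_progress.update(task_id, advance=1)
    f_rich_progress.destroy_task(task_id)
    f_rich_progress.refresh()  # without this, the stale bar may stay on screen
```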


+1 -2   fastNLP/core/controllers/trainer.py

@@ -23,7 +23,6 @@ from fastNLP.core.drivers import Driver
from fastNLP.core.drivers.utils import choose_driver
from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
from fastNLP.envs import rank_zero_call
from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME


@@ -610,7 +609,7 @@ class Trainer(TrainerEventTrigger):
r""" r"""
用于断点重训的加载函数; 用于断点重训的加载函数;
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator;
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler;


注意我们目前不支持单卡到多卡的断点重训; 注意我们目前不支持单卡到多卡的断点重训;




+34 -22   fastNLP/core/dataloaders/torch_dataloader/fdl.py

@@ -24,6 +24,7 @@ class _FDataSet:
    A wrapper around Dataset whose main job is to modify the dataset's __getitem__ so that it additionally returns the index idx; note that the dataset needs to implement __getattribute__ so that the dataset's
    methods can be called from within _FDataset
    """

    def __init__(self, dataset) -> None:
        self.dataset = dataset


@@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader):
    A DataLoader provided for users of the pytorch framework. If it is used together with a FastNLP dataset, the AutoCollate function can automatically pad the data; users can also
    adjust several collate_fn settings through the provided methods.
    """

    def __init__(self, dataset, batch_size: int = 1,
                 shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
                 batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
@@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader):




def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str] = None)\
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str, None] = None) \
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
""" """
传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象


@@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy) as_numpy=as_numpy)
dl.set_input(*input_fields)
if input_fields:
dl.set_input(*input_fields)
return dl return dl


elif isinstance(ds_or_db, DataBundle): elif isinstance(ds_or_db, DataBundle):
@@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy) as_numpy=as_numpy)
else: else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=non_train_sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy) as_numpy=as_numpy)
dl_bundle[name].set_input(*input_fields)
if input_fields:
dl_bundle[name].set_input(*input_fields)
return dl_bundle return dl_bundle


elif isinstance(ds_or_db, Sequence): elif isinstance(ds_or_db, Sequence):
@@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy) as_numpy=as_numpy)
) )
for dl in dl_bundle:
dl.set_input(*input_fields)
if input_fields:
for dl in dl_bundle:
dl.set_input(*input_fields)
return dl_bundle return dl_bundle


elif isinstance(ds_or_db, Mapping): elif isinstance(ds_or_db, Mapping):
@@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy) as_numpy=as_numpy)
else: else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
shuffle=shuffle, sampler=non_train_sampler,
batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator, multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
as_numpy=as_numpy) as_numpy=as_numpy)


dl_bundle[name].set_input(*input_fields)
if input_fields:
dl_bundle[name].set_input(*input_fields)


return dl_bundle return dl_bundle
else: else:
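With the guards above, `input_fields` may now be omitted entirely. A hedged usage sketch (the dataset contents are illustrative; the import path is the module shown in this diff):

```python
from fastNLP import DataSet
from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

ds = DataSet({'x': [[1, 2], [2, 3, 4], [4, 3]], 'y': [0, 1, 1]})

# input_fields given: set_input(*input_fields) is forwarded to the created dataloader
dl = prepare_torch_dataloader(ds, batch_size=2, shuffle=True, input_fields=['x', 'y'])

# input_fields left as None: after this change the set_input call is simply skipped
dl_plain = prepare_torch_dataloader(ds, batch_size=2)
```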


+2 -2   fastNLP/core/drivers/driver.py

@@ -49,13 +49,13 @@ class Driver(ABC):
            must not appear repeatedly on different gpus; when it is 'unrepeatdist', the dataloader should guarantee that the data iterated out on all gpus, merged together, is exactly the original
            data, while the number of batches on different gpus is allowed to differ. When the trainer kwarg `use_dist_sampler` is True, this value is "dist";
            otherwise it is None; when the evaluator kwarg `use_dist_sampler` is True, this value is "unrepeatdist", otherwise it is None;
            note that when dist is a ReproducibleIterator or ReproducibleBatchSampler, it is the driver.load function that is calling this during resumed-training loading;
            note that when dist is a ReproducibleSampler or ReproducibleBatchSampler, it is the driver.load function that is calling this during resumed-training loading;
            when dist is a str or None, the trainer is calling this function at initialization;

        :param reproducible: if False, nothing extra needs to be considered; if True, the returned dataloader must be able to save its current iteration state so that
            it can be loaded again.
        :return: a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here); moreover,
            if what is passed in the dataloader is a ReproducibleIterator or ReproducibleBatchSampler, a new one needs to be re-initialized and put into the returned
            if what is passed in the dataloader is a ReproducibleSampler or ReproducibleBatchSampler, a new one needs to be re-initialized and put into the returned
            dataloader. If dist is empty and reproducible is False, the original object can be returned directly.
        """
        if dist is None and reproducible is False:


+2 -2   fastNLP/core/drivers/jittor_driver/mpi.py

@@ -3,7 +3,7 @@ from typing import Optional, Union


from .jittor_driver import JittorDriver
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleIterator
from fastNLP.core.samplers import ReproducibleSampler

if _NEED_IMPORT_JITTOR:
    import jittor
@@ -70,7 +70,7 @@ class JittorMPIDriver(JittorDriver):
    def test_step(self, batch):
        return self._test_step(batch)

    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                  reproducible: bool = False, sampler_or_batch_sampler=None):
        pass




+6 -6   fastNLP/core/drivers/jittor_driver/single_device.py

@@ -3,7 +3,7 @@ from typing import Dict, Union
from .jittor_driver import JittorDriver
from fastNLP.core.utils import auto_param_call
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler

if _NEED_IMPORT_JITTOR:
    import jittor
@@ -99,25 +99,25 @@ class JittorSingleDriver(JittorDriver):
    def is_distributed(self):
        return False

    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                  reproducible: bool = False, sampler_or_batch_sampler=None):
        # the reproducible-related functionality has not been implemented yet
        if isinstance(dist, ReproducibleBatchSampler):
            raise NotImplementedError
            dataloader.batch_sampler = dist_sample
        if isinstance(dist, ReproducibleIterator):
        if isinstance(dist, ReproducibleSampler):
            raise NotImplementedError
            dataloader.batch_sampler.sampler = dist

        if reproducible:
            raise NotImplementedError
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
                return dataloader
            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
            elif isinstance(dataloader.batch_sampler, RandomBatchSampler):
                return dataloader
            else:
                # TODO
                batch_sampler = ReproducibleBatchSampler(
                batch_sampler = RandomBatchSampler(
                    batch_sampler=dataloader.batch_sampler,
                    batch_size=dataloader.batch_sampler.batch_size,
                    drop_last=dataloader.drop_last


+5 -5   fastNLP/core/drivers/paddle_driver/fleet.py

@@ -19,7 +19,7 @@ from fastNLP.core.utils import (
    paddle_move_data_to_device,
    is_in_paddle_dist,
)
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedRandomSampler
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger

@@ -312,13 +312,13 @@ class PaddleFleetDriver(PaddleDriver):
    def test_step(self, batch):
        return self._test_step(batch)

    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler]],
                                  reproducible: bool = False, sampler_or_batch_sampler=None):
        # iterableDataset is not supported for now
        assert dataloader.dataset_kind != _DatasetKind.ITER, \
            "FastNLP does not support `IteratorDataset` now."
        if isinstance(dist, ReproducibleIterator):
        if isinstance(dist, ReproducibleSampler):
            dataloader.batch_sampler.sampler = dist
            return dataloader

@@ -340,7 +340,7 @@ class PaddleFleetDriver(PaddleDriver):
        # trainer
        elif dist == "dist":
            # if the user's trainer.use_dist_sampler is True, then whether resumed training is performed or not does not affect the behaviour here;
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
                dataloader.batch_sampler.sampler.set_distributed(
                    num_replicas=self.world_size,
                    rank=self.global_rank,
@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver):
            return dataloader
        # evaluator
        elif dist == "unrepeatdist":
            sampler = UnrepeatedSampler(
            sampler = UnrepeatedRandomSampler(
                dataset=dataloader.dataset,
                shuffle=shuffle,
                seed=int(os.environ.get("FASTNLP_SEED", 0))


+4 -4   fastNLP/core/drivers/paddle_driver/single_device.py

@@ -10,7 +10,7 @@ from fastNLP.core.utils import (
    get_paddle_device_id,
    paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -139,7 +139,7 @@ class PaddleSingleDriver(PaddleDriver):
        """
        return paddle_move_data_to_device(batch, "gpu:0")

    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler],
                                  reproducible: bool = False, sampler_or_batch_sampler=None):
        # IteratorDataset is not supported for now
        assert dataloader.dataset_kind != _DatasetKind.ITER, \
@@ -147,12 +147,12 @@ class PaddleSingleDriver(PaddleDriver):
        if isinstance(dist, ReproducibleBatchSampler):
            dataloader.batch_sampler = dist
            return dataloader
        if isinstance(dist, ReproducibleIterator):
        if isinstance(dist, ReproducibleSampler):
            dataloader.batch_sampler.sampler = dist
            return dataloader

        if reproducible:
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator):
            if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler):
                return dataloader
            elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler):
                return dataloader


+26 -15   fastNLP/core/drivers/torch_driver/ddp.py

@@ -28,11 +28,11 @@ from fastNLP.core.drivers.torch_driver.utils import (
)
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler
from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, UnrepeatedSequentialSampler, ReproducibleBatchSampler, \
    re_instantiate_sampler, UnrepeatedSampler, conversion_between_reproducible_and_unrepeated_sampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
from fastNLP.core.samplers import re_instantiate_sampler


class TorchDDPDriver(TorchDriver):
@@ -446,13 +446,23 @@ class TorchDDPDriver(TorchDriver):
        # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
        return self._test_step(batch)

    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None,
                                  reproducible: bool = False):
        # if dist is a ReproducibleBatchSampler or ReproducibleIterator, it means this is the driver.load call made during resumed training;
        # if dist is a ReproducibleBatchSampler or ReproducibleSampler, it means this is the driver.load call made during resumed training;
        # note that dist_sampler.set_distributed does not need to be called here; if the user is using TorchDDPDriver, it was already called when the Trainer was initialized;
        if isinstance(dist, ReproducibleBatchSampler):
            dist.set_distributed(
                num_replicas=self.world_size,
                rank=self.global_rank,
                pad=True
            )
            return replace_batch_sampler(dataloader, dist)
        if isinstance(dist, ReproducibleIterator):
        if isinstance(dist, ReproducibleSampler):
            dist.set_distributed(
                num_replicas=self.world_size,
                rank=self.global_rank,
                pad=True
            )
            return replace_sampler(dataloader, dist)

        # if dist is a str or None, it means the trainer is calling this at initialization;
@@ -465,7 +475,7 @@ class TorchDDPDriver(TorchDriver):
            if isinstance(dist, ReproducibleBatchSampler):
                dist = re_instantiate_sampler(dist)
                return replace_batch_sampler(dataloader, dist)
            if isinstance(dist, ReproducibleIterator):
            if isinstance(dist, ReproducibleSampler):
                dist = re_instantiate_sampler(dist)
                return replace_sampler(dataloader, dist)
            return dataloader
@@ -481,7 +491,7 @@ class TorchDDPDriver(TorchDriver):
                    pad=True
                )
                return replace_batch_sampler(dataloader, batch_sampler)
            elif isinstance(args.sampler, ReproducibleIterator):
            elif isinstance(args.sampler, ReproducibleSampler):
                sampler = re_instantiate_sampler(args.sampler)
                sampler.set_distributed(
                    num_replicas=self.world_size,
@@ -503,14 +513,15 @@ class TorchDDPDriver(TorchDriver):
            return replace_sampler(dataloader, sampler)
        # evaluator
        elif dist == "unrepeatdist":
            # todo @yh, fill in the unrepeatdist-related content;
            args = self.get_dataloader_args(dataloader)

            # todo decide about the batch_sampler;
sampler = UnrepeatedSampler(
dataset=args.dataset,
shuffle=args.shuffle,
)
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)
else:
sampler = re_instantiate_sampler(args.sampler)
            sampler.set_distributed(
                num_replicas=self.world_size,
                rank=self.global_rank
@@ -588,7 +599,7 @@ class TorchDDPDriver(TorchDriver):
        :param group:
        :return:
        """
        return fastnlp_torch_all_gather(obj, device=self.data_device, group=group)
        return fastnlp_torch_all_gather(obj, group=group)


def find_free_network_port() -> str:
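The rewritten `unrepeatdist` branch picks the evaluator-side sampler in three ways: convert an existing `ReproducibleSampler`, fall back to `UnrepeatedSequentialSampler`, or re-instantiate an `UnrepeatedSampler` that is already in place. A hedged sketch of the first path, using only names introduced by this commit (the list dataset is illustrative):

```python
from fastNLP.core.samplers import (
    RandomSampler,
    conversion_between_reproducible_and_unrepeated_sampler,
)

dataset = list(range(10))  # anything implementing __len__ should do

# a reproducible sampler, as the trainer would have installed it
repro = RandomSampler(dataset, shuffle=True)

# the conversion performed inside set_dist_repro_dataloader for the evaluator
unrepeated = conversion_between_reproducible_and_unrepeated_sampler(repro)

# split the indices across two ranks without padding or repetition
unrepeated.set_distributed(num_replicas=2, rank=0)
print(list(unrepeated))  # the subset of indices assigned to rank 0
```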


+145 -227   fastNLP/core/drivers/torch_driver/dist_utils.py

@@ -1,11 +1,8 @@
import io
import pickle
from typing import Mapping
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from abc import ABC
from typing import Any, Union, List
import numpy as np
from typing import Any, List
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8




@@ -13,103 +10,25 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
    import torch
    from torch import distributed as dist
try:
from torch._C._distributed_c10d import ProcessGroupMPI
except ImportError:
_MPI_AVAILABLE = False

try:
from torch._C._distributed_c10d import ProcessGroupNCCL
except ImportError:
_NCCL_AVAILABLE = False

try:
from torch._C._distributed_c10d import ProcessGroupGloo
from torch._C._distributed_c10d import _ProcessGroupWrapper
except ImportError:
_GLOO_AVAILABLE = False


from fastNLP.core.utils import apply_to_collection





def all_gather_object(object_list, obj, group=None):
"""
Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.

Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
object (Any): Pickable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.

.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.

.. note:: For NCCL-based processed groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsiblity to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.

.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
"""
if dist.distributed_c10d._rank_not_in_group(group):
return

input_tensor, local_size = _object_to_tensor(obj)
current_device = torch.device("cpu")
if dist.is_nccl_available() and isinstance(
group or dist.distributed_c10d._get_default_group(), dist.ProcessGroupNCCL
):
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=current_device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)


def _validate_output_list_for_rank(my_rank, dst, gather_list):
    if dst == my_rank:
        if not gather_list:
@@ -123,8 +42,10 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
            )




def gather_object(obj, object_gather_list=None, dst=0, group=None):
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None):
    """
    Gather things from the other ranks onto the dst rank.

    Gathers picklable objects from the whole group in a single process.
    Similar to :func:`gather`, but Python objects can be passed in. Note that the
    object must be picklable in order to be gathered.
@@ -176,6 +97,8 @@ def gather_object(obj, object_gather_list=None, dst=0, group=None):
    # Ensure object_gather_list is specified appopriately.
    my_rank = dist.get_rank()
    _validate_output_list_for_rank(my_rank, dst, object_gather_list)
    # prevent data from ending up on the sending gpu when it is unpickled.
    obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
    input_tensor, local_size = _object_to_tensor(obj)
    group_backend = dist.get_backend(group)
    current_device = torch.device("cpu")
@@ -266,113 +189,11 @@ def send_recv_object(obj, src, cur_rank, device, group=None, tag=0):
    return _tensor_to_object(tensor.cpu(), size)




def _all_gather(obj, **kwargs):
group = kwargs.get('group', None)
if isinstance(obj, torch.Tensor):
gathered_tensor = [torch.zeros_like(obj) for _ in
range(torch.distributed.get_world_size(group=group))]

torch.distributed.all_gather(gathered_tensor, obj, group=group)

return gathered_tensor

elif isinstance(obj, tuple) and isinstance(obj[1], torch.Tensor):
tensor, size = obj
        # the sizes need to be synchronized first, right?
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=tensor.device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
dist.all_gather(object_size_list, size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=tensor.device
)

# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i: max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, tensor, group=group)
object_list = []
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
object_list.append(_tensor_to_object(tensor, tensor_size))
return object_list
elif isinstance(obj, tuple) and len(obj) == 2:
obj, _type = obj
gathered_tensor = [torch.zeros_like(obj) for _ in
range(torch.distributed.get_world_size(group=group))]

torch.distributed.all_gather(gathered_tensor, obj, group=group)

if _type == np.ndarray:
gathered_tensor = [t.detach().cpu().numpy() for t in gathered_tensor]
else:
gathered_tensor = [_type(t.item()) for t in gathered_tensor]

return gathered_tensor
else:
raise RuntimeError("Unsupported types to implement all_gather.")


class CanTransferDataType(ABC):
"""
    Detect objects that can be transferred.

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is CanTransferDataType:
if issubclass(subclass, Mapping):
return False
if subclass in (torch.Tensor, tuple, list, str, int, float, bool, np.ndarray):
return True
return False
return NotImplemented


def _tensorize(obj, device=None):
if isinstance(obj, torch.Tensor):
return obj
if isinstance(obj, bool):
return torch.tensor(obj, dtype=torch.uint8, device=device), bool
if isinstance(obj, float):
return torch.tensor(obj, dtype=torch.float, device=device), float
if isinstance(obj, int):
return torch.tensor(obj, dtype=torch.int, device=device), int
if isinstance(obj, np.ndarray):
return torch.from_numpy(obj), np.ndarray
return _object_to_tensor(obj, device)


def _to_device(tensor, device):
    return tensor.contiguous().to(device)




def convert_to_tensors(data: Any, device=None) -> Any:
data = apply_to_collection(data, CanTransferDataType, _tensorize)
def _move_to_device_and_make_contiguous(t: Union[torch.Tensor, tuple], device: Union[str, torch.device]):
if isinstance(t, tuple):
        if isinstance(t[1], torch.Tensor):  # this means it was converted from an object
            return t[0].to(device).contiguous(), t[1].to(device)
        else:  # this means the second element is a type, see the to_dtype_tensor function
return t[0].to(device).contiguous(), t[1]
return t.to(device).contiguous()

data = apply_to_collection(data, (torch.Tensor, tuple), _move_to_device_and_make_contiguous, device=device)
return data


def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
""" """
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。


@@ -390,36 +211,28 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List:
        {'a': 1, 'b':[1, 2], 'c':{'d': 2}}
    ]


    :param obj: data of any structure; every value will become a list of length world_size, holding, in order, the object's value on each rank
    :param device: which device the current rank is using; if None, torch.cuda.current_device() is used by default.
    :param obj: data of any structure; if it is a tensor, the tensors on every card must have the same shape. Any non-tensor object passed in
        is simply serialized before being transferred.
    :param device: this parameter currently has no effect.
    :param group:
    :return: the result is [obj0, obj1, ...], where obj_i is the obj on the i-th rank.
    """
    # # first move everything to cpu and make it contiguous, to avoid pickle problems
    # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if device is None:
device = torch.cuda.current_device()
if _TORCH_GREATER_EQUAL_1_8:
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
dist.all_gather(objs, obj, group=group)
else:
            objs = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(objs, obj)
            objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device)  # make sure that, if there are tensors, they all end up on the current card
return objs
group = group if group is not None else torch.distributed.group.WORLD
data = convert_to_tensors(obj, device=device)
data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group)

objs = []

def _get_obj_on_idx(obj, idx):
return obj[idx]

for i in range(dist.get_world_size(group)):
objs.append(apply_to_collection(data, dtype=list, function=_get_obj_on_idx, idx=i))

    # prevent data from being unpickled onto the gpu it was sent from
obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
if _TORCH_GREATER_EQUAL_1_8:
dist.all_gather_object(objs, obj, group=group)
else:
objs = all_gather_object(objs, obj, group=group)
    return objs
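A minimal usage sketch of the reworked helper, assuming it is run inside an initialized process group (e.g. launched with torchrun so the default group can be formed); the dict mirrors the structure shown in the docstring above:

```python
import torch
import torch.distributed as dist
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather

def demo():
    # torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for init_process_group
    dist.init_process_group(backend="gloo")
    obj = {'a': dist.get_rank(), 'b': [1, 2], 'c': {'d': torch.arange(2)}}
    gathered = fastnlp_torch_all_gather(obj)  # the device argument is now ignored
    # gathered is [obj_rank0, obj_rank1, ...], one entry per rank
    print(len(gathered) == dist.get_world_size())

if __name__ == "__main__":
    demo()
```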




def fastnlp_torch_broadcast_object(obj, src, device, group=None):
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
""" """
将 src 上的 obj 对象广播到其它 rank 上。 将 src 上的 obj 对象广播到其它 rank 上。


@@ -430,10 +243,9 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None):
    :return:
    """
    cur_rank = dist.get_rank(group)
    # if cur_rank == src:
    #     # if there are tensors, move them all to cpu to make pickling easier
    #     obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))

    if cur_rank == src:
        # if there are tensors, move them all to cpu to make pickling easier, otherwise unpickling may end up on the card they were sent from
        obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu'))
    if _TORCH_GREATER_EQUAL_1_8:
        if cur_rank!=src:
            get_obj = [None]
@@ -442,6 +254,8 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None):
        else:
            dist.broadcast_object_list([obj], src=src, group=group)
        return obj
if device is None:
device = torch.cuda.current_device()


    if cur_rank == src:
        tensor, size = _object_to_tensor(obj, device=device)
@@ -460,3 +274,107 @@ def fastnlp_torch_broadcast_object(obj, src, device, group=None):
        return _tensor_to_object(tensor, tensor_size=size.item())




def _check_for_nccl_backend(group):
pg = group or dist.distributed_c10d._get_default_group()
# It is not expected for PG to be wrapped many times, but support it just
# in case
while isinstance(pg, _ProcessGroupWrapper):
pg = pg.wrapped_pg

return (
dist.is_nccl_available() and
isinstance(pg, dist.ProcessGroupNCCL)
)


def all_gather_object(object_list, obj, group=None):
"""
    Copied from pytorch's code so that it stays compatible with lower versions of pytorch.

Gathers picklable objects from the whole group into a list. Similar to
:func:`all_gather`, but Python objects can be passed in. Note that the object
must be picklable in order to be gathered.

Args:
object_list (list[Any]): Output list. It should be correctly sized as the
size of the group for this collective and will contain the output.
object (Any): Pickable Python object to be broadcast from current process.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used. Default is ``None``.

Returns:
None. If the calling rank is part of this group, the output of the
collective will be populated into the input ``object_list``. If the
calling rank is not part of the group, the passed in ``object_list`` will
be unmodified.

.. note:: Note that this API differs slightly from the :func:`all_gather`
collective since it does not provide an ``async_op`` handle and thus
will be a blocking call.

.. note:: For NCCL-based processed groups, internal tensor representations
of objects must be moved to the GPU device before communication takes
place. In this case, the device used is given by
``torch.cuda.current_device()`` and it is the user's responsiblity to
ensure that this is set so that each rank has an individual GPU, via
``torch.cuda.set_device()``.

.. warning::
:func:`all_gather_object` uses ``pickle`` module implicitly, which is
known to be insecure. It is possible to construct malicious pickle data
which will execute arbitrary code during unpickling. Only call this
function with data you trust.

Example::
>>> # Note: Process group initialization omitted on each rank.
>>> import torch.distributed as dist
>>> # Assumes world_size of 3.
>>> gather_objects = ["foo", 12, {1: 2}] # any picklable object
>>> output = [None for _ in gather_objects]
>>> dist.all_gather_object(output, gather_objects[dist.get_rank()])
>>> output
['foo', 12, {1: 2}]
"""
if dist._rank_not_in_group(group):
return

input_tensor, local_size = _object_to_tensor(obj)
current_device = torch.device("cpu")
is_nccl_backend = _check_for_nccl_backend(group)
if is_nccl_backend:
# See note about using torch.cuda.current_device() here in docstring.
# We cannot simply use my_rank since rank == device is not necessarily
# true.
current_device = torch.device("cuda", torch.cuda.current_device())
input_tensor = input_tensor.to(current_device)
local_size = local_size.to(current_device)
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
object_sizes_tensor = torch.zeros(
group_size, dtype=torch.long, device=current_device
)
object_size_list = [
object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size)
]
# Allgather tensor sizes
dist.all_gather(object_size_list, local_size, group=group)
max_object_size = int(max(object_size_list).item()) # type: ignore[type-var]
# Resize tensor to max size across all ranks.
input_tensor.resize_(max_object_size)
coalesced_output_tensor = torch.empty(
max_object_size * group_size, dtype=torch.uint8, device=current_device
)
# Output tensors are nonoverlapping views of coalesced_output_tensor
output_tensors = [
coalesced_output_tensor[max_object_size * i : max_object_size * (i + 1)]
for i in range(group_size)
]
dist.all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
if tensor.device != torch.device("cpu"):
tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)

+4 -5   fastNLP/core/drivers/torch_driver/single_device.py

@@ -13,9 +13,8 @@ __all__ = [
from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.log import logger
from fastNLP.core.samplers import re_instantiate_sampler


class TorchSingleDriver(TorchDriver):
@@ -130,13 +129,13 @@ class TorchSingleDriver(TorchDriver):
        else:
            return self._test_step(batch)

    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
                                  reproducible: bool = False):

        # if dist is a ReproducibleBatchSampler or ReproducibleIterator, it means this is the driver.load call made during resumed training;
        if isinstance(dist, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, dist)
        elif isinstance(dist, ReproducibleIterator):
        elif isinstance(dist, ReproducibleSampler):
            return replace_sampler(dataloader, dist)

        # if dist is a str or None, it means the trainer is calling this at initialization;
@@ -144,7 +143,7 @@ class TorchSingleDriver(TorchDriver):
        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
            batch_sampler = re_instantiate_sampler(args.batch_sampler)
            return replace_batch_sampler(dataloader, batch_sampler)
        elif isinstance(args.sampler, ReproducibleIterator):
        elif isinstance(args.sampler, ReproducibleSampler):
            sampler = re_instantiate_sampler(args.sampler)
            return replace_sampler(dataloader, sampler)




+6 -7   fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler


class TorchDriver(Driver):
@@ -182,8 +182,8 @@ class TorchDriver(Driver):
        # trainer.dataloader to change the state of the dataloader so that it fits the training or the evaluation environment;

        # 1. the state of the sampler, because we support resume training, i.e. resuming precisely at a specific batch;
        # first, a pytorch DataLoader will always have a sampler; on the other hand, during resumed training we always replace the dataloader's
        # sampler with a `ReproducibleIterator` in `replace_sampler`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
        # first, a pytorch DataLoader will always have a sampler; on the other hand, during resumed training we always replace the dataloader's
        # sampler with a `ReproducibleSampler` in `set_`; otherwise, in the single-card case, the batch_sampler is replaced with a `ReproducibleBatchSampler`;
        dataloader_args = self.get_dataloader_args(dataloader)
        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
            sampler = dataloader_args.batch_sampler
@@ -247,11 +247,10 @@ class TorchDriver(Driver):
        dataloader_args = self.get_dataloader_args(dataloader)
        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
            sampler = dataloader_args.batch_sampler
        elif isinstance(dataloader_args.sampler, ReproducibleIterator):
        elif isinstance(dataloader_args.sampler, ReproducibleSampler):
            sampler = dataloader_args.sampler
        elif self.is_distributed():
            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
                               "`ReproducibleBatchSampler` or `ReproducibleIterator`.")
            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
        else:
            sampler = ReproducibleBatchSampler(
                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
@@ -291,7 +290,7 @@ class TorchDriver(Driver):

    @staticmethod
    def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
        """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed
        """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed
        with ``seed_everything(seed, workers=True)``.

        See also the PyTorch documentation on


+17 -7   fastNLP/core/samplers/__init__.py

@@ -9,18 +9,28 @@ __all__ = [
    'MixSequentialSampler',
    'PollingSampler',

    'ReproducibleIterator',
    'ReproducibleSampler',
    'RandomSampler',
    're_instantiate_sampler',
    "SequentialSampler",
    "SortedSampler",

    'UnrepeatedSampler',
    "UnrepeatedSortedSampler"
    'UnrepeatedRandomSampler',
    "UnrepeatedSortedSampler",
    "UnrepeatedSequentialSampler",

    "RandomBatchSampler",
    "BucketedBatchSampler",
    "ReproducibleBatchSampler",

    "re_instantiate_sampler",
    "conversion_between_reproducible_and_unrepeated_sampler"
]

from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler
from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler
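After this re-export change, all of the renamed classes come from `fastNLP.core.samplers`. A hedged sketch of the imports user code would switch to (the old `ReproducibleIterator` name is gone from the public list):

```python
# names exported by fastNLP.core.samplers after this commit
from fastNLP.core.samplers import (
    ReproducibleSampler,       # was ReproducibleIterator
    RandomBatchSampler,        # was the concrete ReproducibleBatchSampler
    ReproducibleBatchSampler,  # now the abstract base class
    UnrepeatedRandomSampler,   # concrete counterpart of UnrepeatedSampler
    conversion_between_reproducible_and_unrepeated_sampler,
)
```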



+14 -5   fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -1,6 +1,6 @@
__all__ = [
    'BucketedBatchSampler',
    "ReproducibleBatchSampler"
    "RandomBatchSampler"
]

import math
@@ -16,7 +16,10 @@ from fastNLP.core.log import logger
from abc import abstractmethod


class ReproducibleBatchIterator:
class ReproducibleBatchSampler:
    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def set_distributed(self, num_replicas, rank, pad=True):
        raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")
@@ -41,19 +44,25 @@ class ReproducibleBatchIterator:
    def set_epoch(self, epoch):
        pass

    @property
    def batch_idx_in_epoch(self):
        raise NotImplementedError("Each specific batch_sampler should implement its own `batch_idx_in_epoch` property.")


class ReproducibleBatchSampler(ReproducibleBatchIterator):

class RandomBatchSampler(ReproducibleBatchSampler):
    # the values of these two parameters should be fetched through the driver's get_dataloader_args function;
    def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
        """
        A wrapper that makes it possible to restore the state of a batch_sampler object.

        :param batch_sampler: an iterable that yields numbers or lists of numbers. ReproducibleBatchSampler will first iterate over this object once, temporarily store the yielded
        :param batch_sampler: an iterable that yields numbers or lists of numbers. RandomBatchSampler will first iterate over this object once, temporarily store the yielded
            indices, and then emit lists of indices in chunks of batch_size when used.
        :param batch_size: the size of each batch.
        :param drop_last: whether to drop the last batch if it cannot be filled with batch_size samples.
        :param kwargs: used internally by fastNLP.
        """
        super().__init__()

        self.batch_sampler = batch_sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
@@ -138,7 +147,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator):
            (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size


class BucketedBatchSampler(ReproducibleBatchIterator):
class BucketedBatchSampler(ReproducibleBatchSampler):
    def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
                 shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
        """


+133 -13   fastNLP/core/samplers/reproducible_sampler.py

@@ -1,24 +1,21 @@
from typing import Dict, List
from typing import Dict, List, Union
import math
import numpy as np

from fastNLP.core.log import logger
from fastNLP.core.dataset import DataSet

__all__ = [
    'ReproducibleIterator',
    'ReproducibleSampler',
    'RandomSampler',
    're_instantiate_sampler'
    "SortedSampler",
    "SequentialSampler"
]


def re_instantiate_sampler(sampler):
    all_attributes = vars(sampler)
    return type(sampler)(**all_attributes)


class ReproducibleIterator:
class ReproducibleSampler:
    """
    Note that the `__init__` method of every class inheriting from `ReproducibleIterator` needs to take a `**kwargs` parameter, so that we can re-instantiate this sampler
    Note that the `__init__` method of every class inheriting from `ReproducibleSampler` needs to take a `**kwargs` parameter, so that we can re-instantiate this sampler
    or batch_sampler during resumed training; also note that none of the variables initialized in init may start with an underscore, and every variable not set in init must start with an underscore.

    """
@@ -46,7 +43,7 @@ class ReproducibleIterator:
    pass




class RandomSampler(ReproducibleIterator):
class RandomSampler(ReproducibleSampler):
    def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs):
        """

@@ -156,8 +153,8 @@ class RandomSampler(ReproducibleIterator):
            f"we cannot use {self.__class__.__name__} to load it."

        length = states['length']
        assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \
                                            "and current dataset."
        assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
                                            f"and current dataset({len(self.dataset)})."
        self.seed = states['seed']
        self.epoch = states['epoch']
        self.num_consumed_samples = states['num_consumed_samples']
@@ -214,9 +211,132 @@ class RandomSampler(ReproducibleIterator):
            self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))




class SequentialSampler(RandomSampler):
def __init__(self, dataset, dist_mode:str='interval', **kwargs):
"""
Read the dataset in sequential order. In the multi-card case each rank reads at an interval; for example, with two cards, card 0 takes [0, 2, 4, ...] and card 1 takes [1, 3, 5, ...].

:param dataset: any data container that implements __len__.
:param kwargs:
"""
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)

def __iter__(self):
if self.during_iter:  # the previous iteration never finished, so force a re-initialisation
self.num_consumed_samples = 0
self.during_iter = True
indices = self.generate_indices()

if self.pad:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]

assert len(indices) == self.total_size

# subsample
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples


for index in indices:
self.num_consumed_samples += self.num_replicas
yield index
self.during_iter = False
self.num_consumed_samples = 0


def generate_indices(self) -> List[int]:
"""
Generate the list of indices to iterate over; for this sampler it is simply the sequential order 0 .. len(dataset)-1.


:return:
"""
return list(range(len(self.dataset)))


def state_dict(self) -> Dict:
states = {
'num_consumed_samples': self.num_consumed_samples,  # note: this counts the samples consumed across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
}
return states


def load_state_dict(self, states: Dict):
# the state can only be loaded when no iteration is in progress;
assert self.during_iter is False, "Cannot call load_state_dict() when it is " \
"during an unfinished iteration."

assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \
f"we cannot use {self.__class__.__name__} to load it."

length = states['length']
assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \
f"and current dataset({len(self.dataset)})."
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples >= length:  # if the checkpoint was taken after the last sample, reset the counter to 0
self.num_consumed_samples = 0
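A tiny sketch of the interval reading described above (editorial illustration; `set_distributed` is inherited from RandomSampler, and the dataset only needs `__len__`):

class ToyDataset:
    def __len__(self):
        return 10

sampler = SequentialSampler(ToyDataset())
sampler.set_distributed(num_replicas=2, rank=0, pad=False)
print(list(sampler))   # card 0 reads every other index: [0, 2, 4, 6, 8]

sampler = SequentialSampler(ToyDataset())
sampler.set_distributed(num_replicas=2, rank=1, pad=False)
print(list(sampler))   # card 1 reads the remaining ones: [1, 3, 5, 7, 9]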


class SortedSampler(SequentialSampler):
def __init__(self, dataset, length:Union[str, List], **kwargs):
"""
Iterate over the dataset from the longest sample to the shortest according to length. In the multi-card case, because of padding, the last sample yielded may be the longest one.

:param dataset: any data container that implements __len__.
:param length: if a List, it must have the same length as dataset and give the length of each element; a str is only
    supported when dataset is a fastNLP DataSet, in which case it is read as a field name, and int field values are taken as the sample lengths.
:param seed: random seed.
:param kwargs: reserved for internal use by fastNLP.
"""
super().__init__(dataset=dataset, **kwargs)
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if not isinstance(length[0], int):
length = list(map(len, length))
else:
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"

assert len(length) == len(dataset), "The length of `data` and `length` should be equal."

self.length = np.array(length, dtype=int)  # per-sample lengths
self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # indices sorted by length, longest first

def generate_indices(self) -> List[int]:
return self.sorted_indices

def __iter__(self):
if self.during_iter:  # the previous iteration never finished, so force a re-initialisation
self.num_consumed_samples = 0
self.during_iter = True
indices = self.generate_indices()

if self.pad:
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]

assert len(indices) == self.total_size

# subsample
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

for index in indices:
self.num_consumed_samples += self.num_replicas
yield index
self.during_iter = False
self.num_consumed_samples = 0
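For example (an editorial sketch, single card with default settings), supplying the per-sample lengths as a list makes the sampler yield indices from the longest sample to the shortest:

class ToyDataset:
    def __len__(self):
        return 5

lengths = [3, 9, 1, 7, 5]                  # length of each of the 5 samples
sampler = SortedSampler(ToyDataset(), length=lengths)
print(list(sampler))                       # longest first: [1, 3, 4, 0, 2]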



+ 42  - 13    fastNLP/core/samplers/unrepeated_sampler.py

@@ -1,6 +1,8 @@
__all__ = [ __all__ = [
'UnrepeatedSampler',
'UnrepeatedSortedSampler', 'UnrepeatedSortedSampler',
'UnrepeatedSampler'
'UnrepeatedRandomSampler',
"UnrepeatedSequentialSampler"
] ]


from typing import List, Union from typing import List, Union
@@ -10,13 +12,21 @@ import numpy as np




class UnrepeatedSampler: class UnrepeatedSampler:
"""
Base class for samplers that guarantee non-repeating indices in the multi-card setting.
"""
pass


class UnrepeatedRandomSampler(UnrepeatedSampler):
def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs): def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
""" """
考虑在多卡evaluate的场景下,不能重复sample。 考虑在多卡evaluate的场景下,不能重复sample。


:param dataset:
:param shuffle:
:param seed:
:param dataset: any data container that implements __len__.
:param shuffle: whether to shuffle the order in which indices are yielded (controlled by seed).
:param seed: random seed.
:param kwargs: reserved for internal use by fastNLP.
""" """
self.dataset = dataset self.dataset = dataset
self.shuffle = shuffle self.shuffle = shuffle
@@ -33,8 +43,8 @@ class UnrepeatedSampler:
:return: :return:
""" """
num_common = len(self.dataset)//self.num_replicas num_common = len(self.dataset)//self.num_replicas
self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
return self.num_samples
num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
return num_samples


def __iter__(self): def __iter__(self):
indices = self.generate_indices() indices = self.generate_indices()
@@ -83,8 +93,8 @@ class UnrepeatedSampler:
return self return self




class UnrepeatedSortedSampler(UnrepeatedSampler):
def __init__(self, dataset, length:Union[str, List], seed: int = 0):
class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
def __init__(self, dataset, length:Union[str, List], **kwargs):
""" """
将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的 将 dataset 中的数据根据 length 从长到短进行迭代,并且保证在多卡场景下数据不重复。本 sampler 可能导致各个机器上的
batch 数量不完全一致。 batch 数量不完全一致。
@@ -92,11 +102,9 @@ class UnrepeatedSortedSampler(UnrepeatedSampler):
:param dataset: 实现了 __len__ 方法的数据容器。 :param dataset: 实现了 __len__ 方法的数据容器。
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。
:param seed: 设置的随机数种子
:param kwargs: fastNLP 保留使用 :param kwargs: fastNLP 保留使用
""" """
super().__init__(dataset=dataset, shuffle=False, seed=seed)
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
if isinstance(dataset, DataSet): if isinstance(dataset, DataSet):
length = dataset.get_field(length) length = dataset.get_field(length)
if not isinstance(length[0], int): if not isinstance(length[0], int):
@@ -107,8 +115,29 @@ class UnrepeatedSortedSampler(UnrepeatedSampler):


assert len(length) == len(dataset), "The length of `data` and `length` should be equal." assert len(length) == len(dataset), "The length of `data` and `length` should be equal."


self.length = np.array(length, dtype=int) # 按照长到短排列的序号。
self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的
length = np.array(length, dtype=int)  # per-sample lengths
self.sorted_indices = np.argsort(length)[::-1].tolist()  # indices sorted by length, longest first


def generate_indices(self) -> List[int]: def generate_indices(self) -> List[int]:
return self.sorted_indices return self.sorted_indices


class UnrepeatedSequentialSampler(UnrepeatedRandomSampler):
def __init__(self, dataset, **kwargs):
"""
Read the dataset in sequential order. In the multi-card case each rank reads at an interval; for example, with two cards, card 0 takes [0, 2, 4, ...] and card 1 takes [1, 3, 5, ...].

:param dataset: any data container that implements __len__.
:param kwargs:
"""
super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs)

def __iter__(self):
indices = self.generate_indices()
indices = indices[self.rank:len(indices):self.num_replicas]
for index in indices:
yield index

def generate_indices(self) -> List[int]:
return list(range(len(self.dataset)))
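A sketch of how the unrepeated samplers split a dataset across ranks without repeating any index (editorial illustration; it assumes the same `set_distributed(num_replicas, rank)` interface as the reproducible samplers, which is how the drivers configure these classes):

class ToyDataset:
    def __len__(self):
        return 5

for rank in range(2):
    sampler = UnrepeatedSequentialSampler(ToyDataset())
    sampler.set_distributed(num_replicas=2, rank=rank)
    print(rank, list(sampler))   # rank 0 -> [0, 2, 4], rank 1 -> [1, 3]; per-rank counts may differ by one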


+ 42  - 0    fastNLP/core/samplers/utils.py

@@ -0,0 +1,42 @@
__all__ = [
're_instantiate_sampler',
'conversion_between_reproducible_and_unrepeated_sampler'
]

from fastNLP.core.samplers.unrepeated_sampler import *
from fastNLP.core.samplers.reproducible_sampler import *


def conversion_between_reproducible_and_unrepeated_sampler(sampler):
"""
Convert a sampler into its corresponding reproducible or unrepeated version. If the input is an UnrepeatedSampler (or a
ReproducibleSampler) and no matching counterpart is known, a TypeError is raised.

:param sampler:
:return:
"""
assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
"The sampler must be UnrepeatedSampler or ReproducibleSampler"
if isinstance(sampler, UnrepeatedSampler):
if isinstance(sampler, UnrepeatedRandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
elif isinstance(sampler, UnrepeatedSequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
elif isinstance(sampler, UnrepeatedSortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
raise TypeError(f"{sampler.__class__} has no unrepeated version.")
else:
if isinstance(sampler, RandomSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
elif isinstance(sampler, SequentialSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
elif isinstance(sampler, SortedSampler):
return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
raise TypeError(f"{sampler.__class__} has no reproducible version.")


def re_instantiate_sampler(sampler, new_sampler_class=None):
all_attributes = vars(sampler)
if new_sampler_class is not None:
return new_sampler_class(**all_attributes)
return type(sampler)(**all_attributes)
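A short sketch of how these helpers are meant to be used, e.g. by a driver that needs the non-repeating counterpart of a training sampler for multi-card evaluation (editorial illustration; the sampler classes come from the modules above):

class ToyDataset:
    def __len__(self):
        return 8

train_sampler = RandomSampler(ToyDataset(), shuffle=True, seed=42)

# swap to the non-repeating variant for evaluation
eval_sampler = conversion_between_reproducible_and_unrepeated_sampler(train_sampler)
assert isinstance(eval_sampler, UnrepeatedRandomSampler)

# re_instantiate_sampler rebuilds a sampler from its attributes (optionally as another class)
clone = re_instantiate_sampler(train_sampler)
assert isinstance(clone, RandomSampler)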

+ 1  - 3    fastNLP/core/utils/rich_progress.py

@@ -94,9 +94,6 @@ class FRichProgress(Progress, metaclass=Singleton):
self.print = self.console.print self.print = self.console.print
self.log = self.console.log self.log = self.console.log


# start new
self.start()
self.console.show_cursor(show=True)
return self return self


def set_transient(self, transient: bool = True): def set_transient(self, transient: bool = True):
@@ -154,6 +151,7 @@ class FRichProgress(Progress, metaclass=Singleton):
super().start() super().start()
self.console.show_cursor(show=True) self.console.show_cursor(show=True)



if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess( f_rich_progress = FRichProgress().new_progess(
"[progress.description]{task.description}", "[progress.description]{task.description}",


+ 2  - 2    tests/core/dataloaders/paddle_dataloader/test_fdl.py

@@ -1,4 +1,4 @@
import unittest
import pytest


from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
@@ -17,7 +17,7 @@ class RandomDataset(Dataset):
return 10 return 10




class TestPaddle(unittest.TestCase):
class TestPaddle:


def test_init(self): def test_init(self):
# ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) # ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
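The test-suite changes in this and the following files all follow the same mechanical unittest-to-pytest conversion; roughly (editorial summary, illustrative names):

import pytest

def test_example():
    ds = {"x": [1, 2, 3]}
    # was: self.assertEqual(len(ds), 1)
    assert len(ds) == 1
    # was: with self.assertRaises(KeyError): ds["y"]
    with pytest.raises(KeyError):
        ds["y"]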


+ 2  - 2    tests/core/dataloaders/torch_dataloader/test_fdl.py

@@ -1,11 +1,11 @@
import unittest
import pytest


from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle from fastNLP.io.data_bundle import DataBundle




class TestFdl(unittest.TestCase):
class TestFdl:


def test_init_v1(self): def test_init_v1(self):
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})


+ 126  - 127    tests/core/dataset/test_dataset.py

@@ -1,12 +1,12 @@
import os import os
import unittest
import pytest


import numpy as np import numpy as np


from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException




class TestDataSetInit(unittest.TestCase):
class TestDataSetInit:
"""初始化DataSet的办法有以下几种: """初始化DataSet的办法有以下几种:
1) 用dict: 1) 用dict:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase):
def test_init_v1(self): def test_init_v1(self):
# 一维list # 一维list
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
assert ds.field_arrays["y"].content == [[5, 6], ] * 40


def test_init_v2(self): def test_init_v2(self):
# 用dict # 用dict
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True
assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40
assert ds.field_arrays["y"].content == [[5, 6], ] * 40


def test_init_assert(self): def test_init_assert(self):
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
_ = DataSet([[1, 2, 3, 4]] * 10) _ = DataSet([[1, 2, 3, 4]] * 10)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
_ = DataSet(0.00001) _ = DataSet(0.00001)




class TestDataSetMethods(unittest.TestCase):
class TestDataSetMethods:
def test_append(self): def test_append(self):
dd = DataSet() dd = DataSet()
for _ in range(3): for _ in range(3):
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
self.assertEqual(len(dd), 3)
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
assert len(dd) == 3
assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3
assert dd.field_arrays["y"].content == [[5, 6]] * 3


def test_add_field(self): def test_add_field(self):
dd = DataSet() dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.add_field("z", [[5, 6]] * 10) dd.add_field("z", [[5, 6]] * 10)
self.assertEqual(len(dd), 10)
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
assert len(dd) == 10
assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10
assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10
assert dd.field_arrays["z"].content == [[5, 6]] * 10


with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40) dd.add_field("??", [[1, 2]] * 40)


def test_delete_field(self): def test_delete_field(self):
@@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("x", [[1, 2, 3]] * 10)
dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10)
dd.delete_field("x") dd.delete_field("x")
self.assertFalse("x" in dd.field_arrays)
self.assertTrue("y" in dd.field_arrays)
assert ("x" in dd.field_arrays) == False
assert "y" in dd.field_arrays


def test_delete_instance(self): def test_delete_instance(self):
dd = DataSet() dd = DataSet()
@@ -80,30 +80,30 @@ class TestDataSetMethods(unittest.TestCase):
dd.add_field("x", [[1, 2, 3]] * old_length) dd.add_field("x", [[1, 2, 3]] * old_length)
dd.add_field("y", [[1, 2, 3, 4]] * old_length) dd.add_field("y", [[1, 2, 3, 4]] * old_length)
dd.delete_instance(0) dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 1)
assert len(dd) == old_length - 1
dd.delete_instance(0) dd.delete_instance(0)
self.assertEqual(len(dd), old_length - 2)
assert len(dd) == old_length - 2


def test_getitem(self): def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1] ins_1, ins_0 = ds[0], ds[1]
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
self.assertEqual(ins_1["x"], [1, 2, 3, 4])
self.assertEqual(ins_1["y"], [5, 6])
self.assertEqual(ins_0["x"], [1, 2, 3, 4])
self.assertEqual(ins_0["y"], [5, 6])
assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True
assert ins_1["x"] == [1, 2, 3, 4]
assert ins_1["y"] == [5, 6]
assert ins_0["x"] == [1, 2, 3, 4]
assert ins_0["y"] == [5, 6]


sub_ds = ds[:10] sub_ds = ds[:10]
self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10)
assert isinstance(sub_ds, DataSet) == True
assert len(sub_ds) == 10


sub_ds_1 = ds[[10, 0, 2, 3]] sub_ds_1 = ds[[10, 0, 2, 3]]
self.assertTrue(isinstance(sub_ds_1, DataSet))
self.assertEqual(len(sub_ds_1), 4)
assert isinstance(sub_ds_1, DataSet) == True
assert len(sub_ds_1) == 4


field_array = ds['x'] field_array = ds['x']
self.assertTrue(isinstance(field_array, FieldArray))
self.assertEqual(len(field_array), 40)
assert isinstance(field_array, FieldArray) == True
assert len(field_array) == 40


def test_setitem(self): def test_setitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
@@ -120,73 +120,73 @@ class TestDataSetMethods(unittest.TestCase):
assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y'] assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']


def test_get_item_error(self): def test_get_item_error(self):
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds[40:] _ = ds[40:]


with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds["kom"] _ = ds["kom"]


def test_len_(self): def test_len_(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertEqual(len(ds), 40)
assert len(ds) == 40


ds = DataSet() ds = DataSet()
self.assertEqual(len(ds), 0)
assert len(ds) == 0


def test_add_fieldarray(self): def test_add_fieldarray(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40))
self.assertEqual(ds['z'].content, [[7, 8]]*40)
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40))
assert ds['z'].content == [[7, 8]] * 40


with self.assertRaises(RuntimeError):
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10))
with pytest.raises(RuntimeError):
ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10))


with self.assertRaises(TypeError):
with pytest.raises(TypeError):
ds.add_fieldarray('z', [1, 2, 4]) ds.add_fieldarray('z', [1, 2, 4])


def test_copy_field(self): def test_copy_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.copy_field('x', 'z') ds.copy_field('x', 'z')
self.assertEqual(ds['x'].content, ds['z'].content)
assert ds['x'].content == ds['z'].content


def test_has_field(self): def test_has_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue(ds.has_field('x'))
self.assertFalse(ds.has_field('z'))
assert ds.has_field('x') == True
assert ds.has_field('z') == False


def test_get_field(self): def test_get_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds.get_field('z') ds.get_field('z')
x_array = ds.get_field('x') x_array = ds.get_field('x')
self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40)
assert x_array.content == [[1, 2, 3, 4]] * 40


def test_get_all_fields(self): def test_get_all_fields(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
field_arrays = ds.get_all_fields() field_arrays = ds.get_all_fields()
self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40)
self.assertEqual(field_arrays['y'], [[5, 6]] * 40)
assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40
assert field_arrays['y'].content == [[5, 6]] * 40


def test_get_field_names(self): def test_get_field_names(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
field_names = ds.get_field_names() field_names = ds.get_field_names()
self.assertTrue('x' in field_names)
self.assertTrue('y' in field_names)
assert 'x' in field_names
assert 'y' in field_names


def test_apply(self): def test_apply(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000}) ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000})
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx') ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx')
self.assertTrue("rx" in ds.field_arrays)
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
assert ("rx" in ds.field_arrays) == True
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]


ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
self.assertEqual(ds.field_arrays["y"].content[0], 2)
assert ds.field_arrays["y"].content[0] == 2


res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len") res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
self.assertTrue(isinstance(res, list) and len(res) > 0)
self.assertTrue(res[0], 4)
assert (isinstance(res, list) and len(res) > 0) == True
assert res[0] == 4


ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k") ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k")
# expect no exception raised # expect no exception raised
@@ -206,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase):


def modify_inplace(instance): def modify_inplace(instance):
instance['words'] = 1 instance['words'] = 1

ds.apply(modify_inplace) ds.apply(modify_inplace)
# with self.assertRaises(TypeError): # with self.assertRaises(TypeError):
# ds.apply(modify_inplace) # ds.apply(modify_inplace)
@@ -230,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase):


T.apply_more(func_1) T.apply_more(func_1)
# print(T['c'][0, 1, 2]) # print(T['c'][0, 1, 2])
self.assertEqual(list(T["c"].content), [2, 4, 6])
self.assertEqual(list(T["d"].content), [1, 4, 9])
assert list(T["c"].content) == [2, 4, 6]
assert list(T["d"].content) == [1, 4, 9]


res = T.apply_field_more(func_2, "a", modify_fields=False) res = T.apply_field_more(func_2, "a", modify_fields=False)
self.assertEqual(list(T["c"].content), [2, 4, 6])
self.assertEqual(list(T["d"].content), [1, 4, 9])
self.assertEqual(list(res["c"]), [3, 6, 9])
self.assertEqual(list(res["d"]), [1, 8, 27])
assert list(T["c"].content) == [2, 4, 6]
assert list(T["d"].content) == [1, 4, 9]
assert list(res["c"]) == [3, 6, 9]
assert list(res["d"]) == [1, 8, 27]


with self.assertRaises(ApplyResultException) as e:
with pytest.raises(ApplyResultException) as e:
T.apply_more(func_err_1) T.apply_more(func_err_1)
print(e) print(e)


with self.assertRaises(ApplyResultException) as e:
with pytest.raises(ApplyResultException) as e:
T.apply_field_more(func_err_2, "a") T.apply_field_more(func_err_2, "a")
print(e) print(e)


def test_drop(self): def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
self.assertEqual(len(ds), 20)
assert len(ds) == 20


def test_contains(self): def test_contains(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds)
self.assertTrue("y" in ds)
self.assertFalse("z" in ds)
assert ("x" in ds) == True
assert ("y" in ds) == True
assert ("z" in ds) == False


def test_rename_field(self): def test_rename_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.rename_field("x", "xx") ds.rename_field("x", "xx")
self.assertTrue("xx" in ds)
self.assertFalse("x" in ds)
assert ("xx" in ds) == True
assert ("x" in ds) == False


with self.assertRaises(KeyError):
with pytest.raises(KeyError):
ds.rename_field("yyy", "oo") ds.rename_field("yyy", "oo")


def test_split(self): def test_split(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
d1, d2 = ds.split(0.1) d1, d2 = ds.split(0.1)
self.assertEqual(len(d1), len(ds)*0.9)
self.assertEqual(len(d2), len(ds)*0.1)
assert len(d2) == (len(ds) * 0.9)
assert len(d1) == (len(ds) * 0.1)


def test_add_field_v2(self): def test_add_field_v2(self):
ds = DataSet({"x": [3, 4]}) ds = DataSet({"x": [3, 4]})
@@ -282,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase):
def test_save_load(self): def test_save_load(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.save("./my_ds.pkl") ds.save("./my_ds.pkl")
self.assertTrue(os.path.exists("./my_ds.pkl"))
assert os.path.exists("./my_ds.pkl") == True


ds_1 = DataSet.load("./my_ds.pkl") ds_1 = DataSet.load("./my_ds.pkl")
os.remove("my_ds.pkl") os.remove("my_ds.pkl")


def test_add_null(self): def test_add_null(self):
ds = DataSet() ds = DataSet()
with self.assertRaises(RuntimeError) as RE:
with pytest.raises(RuntimeError) as RE:
ds.add_field('test', []) ds.add_field('test', [])


def test_concat(self): def test_concat(self):
@@ -301,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase):
ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]}) ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]})
ds3 = ds1.concat(ds2) ds3 = ds1.concat(ds2)


self.assertEqual(len(ds3), 20)
assert len(ds3) == 20


self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4])
self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1])
assert ds1[9]['x'] == [1, 2, 3, 4]
assert ds1[10]['x'] == [4, 3, 2, 1]


ds2[0]['x'][0] = 100 ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了


ds3[10]['x'][0] = -100 ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了


# 测试inplace # 测试inplace
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@@ -318,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase):
ds3 = ds1.concat(ds2, inplace=True) ds3 = ds1.concat(ds2, inplace=True)


ds2[0]['x'][0] = 100 ds2[0]['x'][0] = 100
self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了
assert ds3[10]['x'][0] == 4 # 不改变copy后的field了


ds3[10]['x'][0] = -100 ds3[10]['x'][0] = -100
self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了
assert ds2[0]['x'][0] == 100 # 不改变copy前的field了


ds3[0]['x'][0] = 100 ds3[0]['x'][0] = 100
self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了
assert ds1[0]['x'][0] == 100 # 改变copy前的field了


# 测试mapping # 测试mapping
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]}) ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]})
ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'}) ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'})
self.assertEqual(len(ds3), 20)
assert len(ds3) == 20


# 测试忽略掉多余的 # 测试忽略掉多余的
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
@@ -340,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase):
# 测试报错 # 测试报错
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]})
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]}) ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]})
with self.assertRaises(RuntimeError):
with pytest.raises(RuntimeError):
ds3 = ds1.concat(ds2, field_mapping={'X': 'x'}) ds3 = ds1.concat(ds2, field_mapping={'X': 'x'})
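The concat semantics exercised above, as a standalone sketch (editorial illustration restating what the assertions check): the second dataset is always copied into the result, so later mutations of ds2 and of the concatenated part are independent, while with inplace=True the result additionally shares the first dataset's original fields:

ds1 = DataSet({"x": [[1, 2]] * 10})
ds2 = DataSet({"x": [[3, 4]] * 10})
ds3 = ds1.concat(ds2, inplace=True)

ds2[0]['x'][0] = 100
assert ds3[10]['x'][0] == 3     # the copied half does not see changes made to ds2

ds3[0]['x'][0] = -1
assert ds1[0]['x'][0] == -1     # the first half still shares ds1's fields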


def test_instance_field_disappear_bug(self): def test_instance_field_disappear_bug(self):
@@ -348,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase):
data.copy_field(field_name='raw_chars', new_field_name='chars') data.copy_field(field_name='raw_chars', new_field_name='chars')
_data = data[:1] _data = data[:1]
for field_name in ['raw_chars', 'target', 'chars']: for field_name in ['raw_chars', 'target', 'chars']:
self.assertTrue(_data.has_field(field_name))
assert _data.has_field(field_name) == True


def test_from_pandas(self): def test_from_pandas(self):
import pandas as pd import pandas as pd
@@ -356,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase):
df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]}) df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]})
ds = DataSet.from_pandas(df) ds = DataSet.from_pandas(df)
print(ds) print(ds)
self.assertEqual(ds['x'].content, [1, 2, 3])
self.assertEqual(ds['y'].content, [4, 5, 6])
assert ds['x'].content == [1, 2, 3]
assert ds['y'].content == [4, 5, 6]


def test_to_pandas(self): def test_to_pandas(self):
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
@@ -366,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase):
def test_to_csv(self): def test_to_csv(self):
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
ds.to_csv("1.csv") ds.to_csv("1.csv")
self.assertTrue(os.path.exists("1.csv"))
assert os.path.exists("1.csv") == True
os.remove("1.csv") os.remove("1.csv")


def test_add_collate_fn(self): def test_add_collate_fn(self):
@@ -374,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase):


def collate_fn(item): def collate_fn(item):
return item return item
ds.add_collate_fn(collate_fn)


self.assertEqual(len(ds.collate_fns.collators), 2)
ds.add_collate_fn(collate_fn)


def test_get_collator(self): def test_get_collator(self):
from typing import Callable from typing import Callable
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
collate_fn = ds.get_collator() collate_fn = ds.get_collator()
self.assertEqual(isinstance(collate_fn, Callable), True)
assert isinstance(collate_fn, Callable) == True


def test_add_seq_len(self): def test_add_seq_len(self):
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
ds.add_seq_len('x') ds.add_seq_len('x')
print(ds) print(ds)


def test_set_target(self): def test_set_target(self):
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]})
ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]})
ds.set_target('x') ds.set_target('x')




class TestFieldArrayInit(unittest.TestCase):
class TestFieldArrayInit:
""" """
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
@@ -442,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase):
# list of array # list of array
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])]) fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])])



def test_init_v8(self): def test_init_v8(self):
# 二维list # 二维list
val = np.array([[1, 2], [3, 4]]) val = np.array([[1, 2], [3, 4]])
@@ -450,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase):
fa.append(val) fa.append(val)




class TestFieldArray(unittest.TestCase):
class TestFieldArray:
def test_main(self): def test_main(self):
fa = FieldArray("x", [1, 2, 3, 4, 5]) fa = FieldArray("x", [1, 2, 3, 4, 5])
self.assertEqual(len(fa), 5)
assert len(fa) == 5
fa.append(6) fa.append(6)
self.assertEqual(len(fa), 6)
assert len(fa) == 6


self.assertEqual(fa[-1], 6)
self.assertEqual(fa[0], 1)
assert fa[-1] == 6
assert fa[0] == 1
fa[-1] = 60 fa[-1] = 60
self.assertEqual(fa[-1], 60)
assert fa[-1] == 60


self.assertEqual(fa.get(0), 1)
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
assert fa.get(0) == 1
assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True
assert list(fa.get([0, 1, 2])) == [1, 2, 3]


def test_getitem_v1(self): def test_getitem_v1(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]
ans = fa[[0, 1]] ans = fa[[0, 1]]
self.assertTrue(isinstance(ans, np.ndarray))
self.assertTrue(isinstance(ans[0], np.ndarray))
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
self.assertEqual(ans.dtype, np.float64)
assert isinstance(ans, np.ndarray) == True
assert isinstance(ans[0], np.ndarray) == True
assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
assert ans[1].tolist() == [1, 2, 3, 4, 5]
assert ans.dtype == np.float64


def test_getitem_v2(self): def test_getitem_v2(self):
x = np.random.rand(10, 5) x = np.random.rand(10, 5)
fa = FieldArray("my_field", x) fa = FieldArray("my_field", x)
indices = [0, 1, 3, 4, 6] indices = [0, 1, 3, 4, 6]
for a, b in zip(fa[indices], x[indices]): for a, b in zip(fa[indices], x[indices]):
self.assertListEqual(a.tolist(), b.tolist())
assert a.tolist() == b.tolist()


def test_append(self): def test_append(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
assert len(fa) == 3
assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6]


def test_pop(self): def test_pop(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]])
fa.pop(0) fa.pop(0)
self.assertEqual(len(fa), 1)
self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0])
assert len(fa) == 1
assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0]
fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5] fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5]
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5]




class TestCase(unittest.TestCase):
class TestCase:


def test_init(self): def test_init(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
self.assertTrue(isinstance(ins.fields, dict))
self.assertEqual(ins.fields, fields)
assert isinstance(ins.fields, dict) == True
assert ins.fields == fields


ins = Instance(**fields) ins = Instance(**fields)
self.assertEqual(ins.fields, fields)
assert ins.fields == fields


def test_add_field(self): def test_add_field(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(**fields) ins = Instance(**fields)
ins.add_field("z", [1, 1, 1]) ins.add_field("z", [1, 1, 1])
fields.update({"z": [1, 1, 1]}) fields.update({"z": [1, 1, 1]})
self.assertEqual(ins.fields, fields)
assert ins.fields == fields


def test_get_item(self): def test_get_item(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
ins = Instance(**fields) ins = Instance(**fields)
self.assertEqual(ins["x"], [1, 2, 3])
self.assertEqual(ins["y"], [4, 5, 6])
self.assertEqual(ins["z"], [1, 1, 1])
assert ins["x"] == [1, 2, 3]
assert ins["y"] == [4, 5, 6]
assert ins["z"] == [1, 1, 1]


def test_repr(self): def test_repr(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}


+ 2  - 2    tests/core/drivers/paddle_driver/test_single_device.py

@@ -10,7 +10,7 @@ from paddle.io import DataLoader, BatchSampler


from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver
from fastNLP.core.samplers.reproducible_sampler import RandomSampler from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.samplers import RandomBatchSampler
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification from tests.helpers.models.paddle_model import PaddleNormalModel_Classification
from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset from tests.helpers.datasets.paddle_data import PaddleDataset_MNIST, PaddleRandomDataset
from fastNLP.core import synchronize_safe_rm from fastNLP.core import synchronize_safe_rm
@@ -153,7 +153,7 @@ class TestSingleDeviceFunction:


@pytest.mark.parametrize( @pytest.mark.parametrize(
"dist_sampler", "dist_sampler",
["dist", ReproducibleBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))]
["dist", RandomBatchSampler(BatchSampler(PaddleDataset_MNIST("train")), 32, False), RandomSampler(PaddleDataset_MNIST("train"))]
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"reproducible", "reproducible",


+ 6  - 31    tests/core/drivers/torch_driver/test_dist_utils.py

@@ -7,38 +7,10 @@ import numpy as np
# print(isinstance((1,), tuple)) # print(isinstance((1,), tuple))
# exit() # exit()


from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, convert_to_tensors, fastnlp_torch_broadcast_object
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context from tests.helpers.utils import re_run_current_cmd_for_torch, magic_argv_env_context





def test_convert_to_tensors():
local_rank = 0
obj = {
'tensor': torch.full(size=(2,), fill_value=local_rank),
'numpy': np.full(shape=(1,), fill_value=local_rank),
'bool': local_rank % 2 == 0,
'float': local_rank + 0.1,
'int': local_rank,
'dict': {
'rank': local_rank
},
'list': [local_rank] * 2,
'str': 'xxx'
}
data = convert_to_tensors(obj)
assert len(data) == len(obj)
assert (data['tensor'] == obj['tensor']).sum() == 2
for name in ['list', 'str']:
assert len(data[name])==2 and isinstance(data[name][0], torch.Tensor) and \
isinstance(data[name][1], torch.Tensor) and data[name][1].ndim==1

for name in ['numpy', 'bool', 'float', 'int']:
assert isinstance(data[name][0], torch.Tensor) and data[name][0].numel()==1

assert isinstance(data['dict']['rank'][0], torch.Tensor) and data[name][0].numel() == 1


@magic_argv_env_context @magic_argv_env_context
def test_fastnlp_torch_all_gather(): def test_fastnlp_torch_all_gather():
os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_ADDR'] = '127.0.0.1'
@@ -66,7 +38,7 @@ def test_fastnlp_torch_all_gather():
'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(), 'tensors': [torch.full(size=(2,), fill_value=local_rank).cuda(),
torch.full(size=(2,), fill_value=local_rank).cuda()] torch.full(size=(2,), fill_value=local_rank).cuda()]
} }
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device())
data = fastnlp_torch_all_gather(obj)
world_size = int(os.environ['WORLD_SIZE']) world_size = int(os.environ['WORLD_SIZE'])
assert len(data) == world_size assert len(data) == world_size
for i in range(world_size): for i in range(world_size):
@@ -81,10 +53,12 @@ def test_fastnlp_torch_all_gather():
assert data[i]['tensors'][0][0] == i assert data[i]['tensors'][0][0] == i


for obj in [1, True, 'xxx']: for obj in [1, True, 'xxx']:
data = fastnlp_torch_all_gather(obj, device=torch.cuda.current_device())
data = fastnlp_torch_all_gather(obj)
assert len(data)==world_size assert len(data)==world_size
assert data[0]==data[1] assert data[0]==data[1]


dist.destroy_process_group()

@magic_argv_env_context @magic_argv_env_context
def test_fastnlp_torch_broadcast_object(): def test_fastnlp_torch_broadcast_object():
os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_ADDR'] = '127.0.0.1'
@@ -130,3 +104,4 @@ def test_fastnlp_torch_broadcast_object():
for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]: for obj in [int(os.environ['LOCAL_RANK']), bool(os.environ['LOCAL_RANK']=='1'), os.environ['LOCAL_RANK']]:
data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device()) data = fastnlp_torch_broadcast_object(obj, src=0, device=torch.cuda.current_device())
assert int(data)==0 assert int(data)==0
dist.destroy_process_group()
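For reference, a hedged sketch of how these helpers are exercised, mirroring the updated tests above (a torch.distributed process group must already be initialised on every rank, and LOCAL_RANK set):

import os
import torch
import torch.distributed as dist
from fastNLP.core.drivers.torch_driver.dist_utils import (
    fastnlp_torch_all_gather, fastnlp_torch_broadcast_object)

# assumes dist.init_process_group(...) has already been called on every rank
local_rank = int(os.environ['LOCAL_RANK'])

gathered = fastnlp_torch_all_gather({'rank': local_rank})   # one entry per rank
assert len(gathered) == dist.get_world_size()
assert gathered[0]['rank'] == 0

value = fastnlp_torch_broadcast_object(local_rank, src=0,
                                        device=torch.cuda.current_device())
assert int(value) == 0            # every rank receives rank 0's object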

+ 1  - 1    tests/core/drivers/torch_driver/test_torch_replace_sampler.py

@@ -30,7 +30,7 @@ class SequenceDataSet:




def check_replace_sampler(driver): def check_replace_sampler(driver):
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler
# dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler
# reproducible 是 True 和 False # reproducible 是 True 和 False


# 需要 check 返回的 sampler 和 dataloader 都不同了 # 需要 check 返回的 sampler 和 dataloader 都不同了


+ 9  - 9    tests/core/samplers/test_reproducible_batch_sampler.py

@@ -4,7 +4,7 @@ import numpy as np
import pytest import pytest
from itertools import chain from itertools import chain


from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler
from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.datasets.torch_data import TorchNormalDataset


@@ -18,7 +18,7 @@ class TestReproducibleBatchSampler:
before_batch_size = 7 before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100) dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size) dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler) dataloader = replace_batch_sampler(dataloader, re_batchsampler)


forward_steps = 3 forward_steps = 3
@@ -28,15 +28,15 @@ class TestReproducibleBatchSampler:


# 1. 保存状态 # 1. 保存状态
_get_re_batchsampler = dataloader.batch_sampler _get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict() state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
"sampler_type": "ReproducibleBatchSampler"}
"sampler_type": "RandomBatchSampler"}


# 2. 断点重训,重新生成一个 dataloader; # 2. 断点重训,重新生成一个 dataloader;
# 不改变 batch_size; # 不改变 batch_size;
dataloader = DataLoader(dataset, batch_size=before_batch_size) dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state) re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler) dataloader = replace_batch_sampler(dataloader, re_batchsampler)


@@ -53,7 +53,7 @@ class TestReproducibleBatchSampler:
# 改变 batch_size; # 改变 batch_size;
after_batch_size = 3 after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size) dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state) re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler) dataloader = replace_batch_sampler(dataloader, re_batchsampler)


@@ -99,7 +99,7 @@ class TestReproducibleBatchSampler:
dataset = TorchNormalDataset(num_of_data=100) dataset = TorchNormalDataset(num_of_data=100)
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler) dataloader = replace_batch_sampler(dataloader, re_batchsampler)


# 将一轮的所有数据保存下来,看是否恢复的是正确的; # 将一轮的所有数据保存下来,看是否恢复的是正确的;
@@ -111,13 +111,13 @@ class TestReproducibleBatchSampler:


# 1. 保存状态 # 1. 保存状态
_get_re_batchsampler = dataloader.batch_sampler _get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
assert isinstance(_get_re_batchsampler, RandomBatchSampler)
state = _get_re_batchsampler.state_dict() state = _get_re_batchsampler.state_dict()


# 2. 断点重训,重新生成一个 dataloader; # 2. 断点重训,重新生成一个 dataloader;
# 不改变 batch_size; # 不改变 batch_size;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state) re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler) dataloader = replace_batch_sampler(dataloader, re_batchsampler)
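Written out as a sketch (editorial illustration; dataset and helpers as imported in this test file), the workflow these assertions cover is: wrap, consume part of an epoch, snapshot, rebuild the dataloader, possibly with a different batch size, and resume:

dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=4)
dataloader = replace_batch_sampler(
    dataloader, RandomBatchSampler(dataloader.batch_sampler, 4, drop_last=False))

it = iter(dataloader)
_ = [next(it) for _ in range(3)]                 # consume 3 batches (12 samples)
state = dataloader.batch_sampler.state_dict()    # snapshot mid-epoch

dataloader = DataLoader(dataset, batch_size=2)   # resume with a different batch size
re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, 2, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)
# iterating now starts from the 13th sample of the recorded index_list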




+ 431  - 107    tests/core/samplers/test_reproducible_sampler.py

@@ -1,18 +1,14 @@
import unittest

from itertools import product
import numpy as np import numpy as np
import pytest


from functools import partial from functools import partial
from array import array
from itertools import chain


from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.datasets.torch_data import TorchNormalDataset





class TestRandomSamplerYh(unittest.TestCase):
class TestRandomSamplerYh:
def test_init(self): def test_init(self):
# 测试能否正确初始化 # 测试能否正确初始化
dataset = TorchNormalDataset(num_of_data=100) dataset = TorchNormalDataset(num_of_data=100)
@@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase):
dataset = TorchNormalDataset(num_of_data=100) dataset = TorchNormalDataset(num_of_data=100)
sampler = RandomSampler(dataset) sampler = RandomSampler(dataset)
for i in sampler: for i in sampler:
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
sampler.set_distributed(1, 0) sampler.set_distributed(1, 0)
break break


@@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase):
dataset = TorchNormalDataset(num_of_data=100) dataset = TorchNormalDataset(num_of_data=100)
sampler = RandomSampler(dataset, shuffle=False) sampler = RandomSampler(dataset, shuffle=False)
sampler.set_distributed(num_replicas=2, rank=0, pad=False) sampler.set_distributed(num_replicas=2, rank=0, pad=False)
self.assertEqual(len(sampler), 50)
assert len(sampler)==50
count = 0 count = 0
for i in sampler: for i in sampler:
self.assertEqual(i%2, 0)
assert i%2==0
count += 1 count += 1
self.assertEqual(count, 50)
assert count == 50


sampler.set_distributed(num_replicas=2, rank=1, pad=False) sampler.set_distributed(num_replicas=2, rank=1, pad=False)
self.assertEqual(len(sampler), 50)
assert len(sampler)==50
count = 0 count = 0
for i in sampler: for i in sampler:
self.assertEqual(i%2, 1)
assert i%2==1
count += 1 count += 1
self.assertEqual(count, 50)
assert count==50


dataset = TorchNormalDataset(num_of_data=101) dataset = TorchNormalDataset(num_of_data=101)
sampler = RandomSampler(dataset, shuffle=False) sampler = RandomSampler(dataset, shuffle=False)
sampler.set_distributed(num_replicas=2, rank=0, pad=True) sampler.set_distributed(num_replicas=2, rank=0, pad=True)
self.assertEqual(len(sampler), 51)
assert len(sampler)==51
count = 0 count = 0
for i in sampler: for i in sampler:
self.assertEqual(i%2, 0)
assert i%2==0
count += 1 count += 1
self.assertEqual(count, 51)
assert count == 51


sampler.set_distributed(num_replicas=2, rank=1, pad=True) sampler.set_distributed(num_replicas=2, rank=1, pad=True)
self.assertEqual(len(sampler), 51)
assert len(sampler) == 51
count = 0 count = 0
for i in sampler: for i in sampler:
if i!=0: if i!=0:
self.assertEqual(i%2, 1)
assert i%2==1
count += 1 count += 1
self.assertEqual(count, 51)
assert count == 51


def test_state_dict_check_length(self): def test_state_dict_check_length(self):
dataset = TorchNormalDataset(num_of_data=100) dataset = TorchNormalDataset(num_of_data=100)
@@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase):
states = sampler.state_dict() states = sampler.state_dict()


new_ds = TorchNormalDataset(num_of_data=10) new_ds = TorchNormalDataset(num_of_data=10)
with self.assertRaises(AssertionError):
with pytest.raises(AssertionError):
new_sampler = RandomSampler(new_ds) new_sampler = RandomSampler(new_ds)
new_sampler.load_state_dict(states) new_sampler.load_state_dict(states)


@@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase):
new_sampler = RandomSampler(new_ds) new_sampler = RandomSampler(new_ds)
new_sampler.load_state_dict(states) new_sampler.load_state_dict(states)


def test_state_dict(self):
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('pre_shuffle', [True, False])
@pytest.mark.parametrize('post_shuffle', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
num_samples = 100 num_samples = 100
dataset = TorchNormalDataset(num_of_data=num_samples) dataset = TorchNormalDataset(num_of_data=num_samples)
# 测试使用 前后shuffle不一致的load操作 # 测试使用 前后shuffle不一致的load操作
lst = [0]+np.random.randint(1, num_samples, size=3).tolist()
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
lst):
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
sampler = RandomSampler(dataset, shuffle=pre_shuffle)
sampler.set_epoch(0)
already_numbers = set()
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
self.assertEqual(len(already_numbers), num_consumed_samples)

states = sampler.state_dict()

new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
self.assertNotIn(i, already_numbers)

# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
new_sampler.set_epoch(0)
count = 0
for i in new_sampler:
self.assertNotIn(i, other_rank_number)
other_rank_number.add(i)
self.assertNotIn(i, already_numbers)
count += 1

def test_state_dict_2(self):
sampler = RandomSampler(dataset, shuffle=pre_shuffle)
sampler.set_epoch(0)
already_numbers = set()
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples

states = sampler.state_dict()

new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
assert i not in already_numbers

# 测试切换成多卡也没有问题
other_rank_number = set()
for rank in range(3):
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
new_sampler.set_epoch(0)
count = 0
seen = 0
seen_in_other_rank = 0
for i in new_sampler:
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1 # 因为pad可能重复

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('pre_shuffle', [True, False])
@pytest.mark.parametrize('post_shuffle', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
# Test switching from multiple GPUs to a single GPU, or to a different number of GPUs
num_samples = 100
dataset = TorchNormalDataset(num_of_data=num_samples)
# Test loading when the shuffle setting differs before and after
lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist()
# lst = [30]
for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False],
lst):
with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples):
already_numbers = set()
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
sampler.set_distributed(num_replicas=2, rank=0)
sampler.set_epoch(0)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=2, rank=1)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
self.assertEqual(len(already_numbers), num_consumed_samples*2)

states = sampler.state_dict()

new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
self.assertNotIn(i, already_numbers)

# Test that switching to multiple GPUs also works
other_rank_number = set()
for rank in range(3):
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False)
count = 0
for i in new_sampler:
self.assertNotIn(i, other_rank_number)
other_rank_number.add(i)
self.assertNotIn(i, already_numbers)
count += 1


class TestRandomSampler(unittest.TestCase):
already_numbers = set()
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
sampler.set_distributed(num_replicas=2, rank=0)
sampler.set_epoch(0)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=2, rank=1)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples*2

states = sampler.state_dict()

new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
assert i not in already_numbers

# Test that switching to multiple GPUs also works
other_rank_number = set()
for rank in range(3):
new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
count = 0
seen = 0
seen_in_other_rank = 0
for i in new_sampler:
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1  # because padding may introduce a duplicate
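
For reference, the checkpoint-and-resume pattern these RandomSampler tests exercise looks roughly like the sketch below. It is not part of the diff: it assumes only the API used above (state_dict, load_state_dict, set_epoch, set_distributed), and a plain list stands in for the TorchNormalDataset used in the tests.

from fastNLP.core.samplers import RandomSampler

dataset = list(range(100))                     # stand-in for the test dataset
sampler = RandomSampler(dataset, shuffle=True, seed=0)
sampler.set_epoch(0)

consumed = set()
for i, idx in enumerate(sampler, start=1):
    consumed.add(idx)
    if i == 10:                                # pretend training stopped after 10 samples
        break

states = sampler.state_dict()                  # checkpoint the sampler alongside the model

# Resume into a fresh sampler; the remaining indices must not repeat the
# consumed ones, even if the shuffle flag differs.
new_sampler = RandomSampler(dataset, shuffle=False)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
assert not set(new_sampler) & consumed

# The same state can also be restored across several ranks; with pad=False
# the ranks split the remaining indices without overlap or repetition.
for rank in range(3):
    shard = RandomSampler(dataset, shuffle=False)
    shard.load_state_dict(states)
    shard.set_distributed(num_replicas=3, rank=rank, pad=False)
    shard.set_epoch(0)
    assert not set(shard) & consumed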


class TestRandomSampler:
# Test single GPU
def test_seed_work_when_shuffle_is_true(self):
data_length = 100
@@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase):
...




class DatasetWithVaryLength:
def __init__(self, num_of_data=100, reverse=False):
self.data = np.arange(num_of_data)
if reverse:
self.data = self.data[::-1]

def __getitem__(self, item):
return self.data[item]

def __len__(self):
return len(self.data)


class TestSortedSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = SortedSampler(data, length=data.data)
indexes = list(sampler)
assert indexes==list(range(num_of_data-1, -1, -1))

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = SortedSampler(dataset=data, length=data.data)
sampler.set_distributed(num_replica, rank=i, pad=pad)
samplers.append(sampler)

# Make sure the order is not scrambled
already_seen_index = set()
for sampler in samplers:
larger_count = 0  # starting from 0 is fine here, because the padded indices are always among the larger values.
prev_index = float('inf')
cur_set = set()
seen_in_other_rank = 0
for index in sampler:
seen_in_other_rank += int(index in already_seen_index)  # different ranks must not overlap
cur_set.add(index)
larger_count += int(index <= prev_index)
prev_index = index
assert larger_count+1 >= len(sampler)  # only the last index may be out of place; all others must keep this order
assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
already_seen_index.update(cur_set)

indexes = list(chain(*samplers))
indexes = set(indexes)
if pad:
assert indexes == set(range(num_of_data))
else:
assert len(indexes) <= num_of_data

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
def test_state_dict(self, pad, num_consumed_samples):
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# Test loading when the shuffle setting differs before and after
sampler = SortedSampler(dataset, length=dataset.data)
sampler.set_epoch(0)
already_numbers = set()
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j<max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples

states = sampler.state_dict()

new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i < max(already_numbers)
assert i not in already_numbers

# Test that switching to multiple GPUs also works
other_rank_number = set()
for rank in range(3):
new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
new_sampler.set_epoch(0)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i >= max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1  # because padding may introduce a duplicate
assert smaller<=1 if pad else smaller==0

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
def test_state_dict_2(self, pad, num_consumed_samples):
# Test switching from multiple GPUs to a single GPU, or to a different number of GPUs
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# Test loading when the shuffle setting differs before and after
# lst = [30]
already_numbers = set()
sampler = SortedSampler(dataset, length=dataset.data)
sampler.set_distributed(num_replicas=2, rank=0)
sampler.set_epoch(0)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j<=max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
sampler = SortedSampler(dataset, length=dataset.data)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=2, rank=1)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples*2

states = sampler.state_dict()

new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i < max(already_numbers)
assert i not in already_numbers

# Test that switching to multiple GPUs also works
other_rank_number = set()
for rank in range(3):
new_sampler = SortedSampler(dataset, length=dataset.data)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i>=max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1  # because padding may introduce a duplicate
assert smaller <= 1 if pad else smaller == 0
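
As a reading aid, the ordering and sharding contract that TestSortedSampler checks can be summarised by the following sketch. It is illustrative only and relies solely on the SortedSampler API shown above, with a plain list as the dataset and a made-up increasing length array.

import numpy as np
from fastNLP.core.samplers import SortedSampler

lengths = np.arange(10)                        # sample i has "length" i
dataset = list(range(10))

# Single rank: indices come out longest-first.
sampler = SortedSampler(dataset, length=lengths)
assert list(sampler) == list(range(9, -1, -1))

# Two ranks with pad=True: together the ranks cover every index, and since
# 10 divides evenly by 2 no padding is needed, so each shard stays
# non-increasing in length.
shards = []
for rank in range(2):
    s = SortedSampler(dataset, length=lengths)
    s.set_distributed(num_replicas=2, rank=rank, pad=True)
    shards.append(list(s))
assert set(shards[0]) | set(shards[1]) == set(range(10))
for shard in shards:
    assert all(a >= b for a, b in zip(shard, shard[1:]))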


class TestSequentialSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = SequentialSampler(data)
indexes = list(sampler)
assert indexes==list(range(num_of_data))

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, pad, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = SequentialSampler(dataset=data)
sampler.set_distributed(num_replica, rank=i, pad=pad)
samplers.append(sampler)

# Make sure the order is not scrambled
already_seen_index = set()
for idx, sampler in enumerate(samplers):
larger_count = 1
prev_index = float('inf')
cur_set = set()
seen_in_other_rank = 0
for index in sampler:
seen_in_other_rank += int(index in already_seen_index)  # different ranks must not overlap
cur_set.add(index)
larger_count += int(index >= prev_index)
prev_index = index
assert larger_count+1 >= len(sampler)  # only the last index may be out of place; all others must keep this order
assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
already_seen_index.update(cur_set)

indexes = list(chain(*samplers))
indexes = set(indexes)
if pad:
assert indexes == set(range(num_of_data))
else:
assert len(indexes) <= num_of_data

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
def test_state_dict(self, pad, num_consumed_samples):
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# Test loading when the shuffle setting differs before and after
sampler = SequentialSampler(dataset=dataset)
sampler.set_epoch(0)
already_numbers = set()
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j>max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples

states = sampler.state_dict()

new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i > max(already_numbers)
assert i not in already_numbers

# Test that switching to multiple GPUs also works
other_rank_number = set()
for rank in range(3):
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
new_sampler.set_epoch(0)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i <= max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=rank  # because padding may introduce a duplicate
assert smaller<=1 if pad else smaller==0

@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
def test_state_dict_2(self, pad, num_consumed_samples):
# Test switching from multiple GPUs to a single GPU, or to a different number of GPUs
num_samples = 100
dataset = DatasetWithVaryLength(num_of_data=num_samples)
# Test loading when the shuffle setting differs before and after
# lst = [30]
already_numbers = set()
sampler = SequentialSampler(dataset=dataset)
sampler.set_distributed(num_replicas=2, rank=0)
sampler.set_epoch(0)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
if already_numbers:
assert j>max(already_numbers)
already_numbers.add(j)
if i == num_consumed_samples:
break
sampler = SequentialSampler(dataset=dataset)
sampler.set_epoch(0)
sampler.set_distributed(num_replicas=2, rank=1)
if num_consumed_samples>0:
for i, j in enumerate(sampler, start=1):
already_numbers.add(j)
if i == num_consumed_samples:
break
assert len(already_numbers) == num_consumed_samples*2

states = sampler.state_dict()

new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for i in new_sampler:
if already_numbers:
assert i > max(already_numbers)
assert i not in already_numbers

# Test that switching to multiple GPUs also works
other_rank_number = set()
for rank in range(3):
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
count = 0
seen = 0
seen_in_other_rank = 0
smaller = 0
for i in new_sampler:
if already_numbers:
smaller += int(i<max(already_numbers))
seen_in_other_rank += int(i in other_rank_number)
other_rank_number.add(i)
seen += int(i in already_numbers)
count += 1
assert seen <= 1 if pad else seen == 0
assert seen_in_other_rank<=1  # because padding may introduce a duplicate
assert smaller <= rank if pad else smaller == 0
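
The SequentialSampler state_dict round-trip tested above amounts to the following single-rank sketch (illustrative only; a plain list stands in for DatasetWithVaryLength).

from fastNLP.core.samplers import SequentialSampler

dataset = list(range(100))
sampler = SequentialSampler(dataset=dataset)
sampler.set_epoch(0)

consumed = set()
for i, idx in enumerate(sampler, start=1):
    consumed.add(idx)
    if i == 30:                    # stop mid-epoch after 30 samples
        break
states = sampler.state_dict()

# After loading the state, only indices beyond the consumed ones are yielded.
new_sampler = SequentialSampler(dataset=dataset)
new_sampler.load_state_dict(states)
new_sampler.set_epoch(0)
for idx in new_sampler:
    assert idx > max(consumed)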




+ 49
- 9
tests/core/samplers/test_unrepeated_sampler.py View File

@@ -2,7 +2,7 @@ from itertools import chain


import pytest


from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler
from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler




class DatasetWithVaryLength:
@@ -21,7 +21,7 @@ class TestUnrepeatedSampler:
def test_single(self, shuffle):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = UnrepeatedSampler(data, shuffle)
sampler = UnrepeatedRandomSampler(data, shuffle)
indexes = set(sampler)
assert indexes==set(range(num_of_data))


@@ -32,17 +32,18 @@ class TestUnrepeatedSampler:
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle)
sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
sampler.set_distributed(num_replica, rank=i)
samplers.append(sampler)


indexes = set(chain(*samplers))
indexes = list(chain(*samplers))
assert len(indexes) == num_of_data
indexes = set(indexes)
assert indexes==set(range(num_of_data))




class TestUnrepeatedSortedSampler:
@pytest.mark.parametrize('shuffle', [True, False])
def test_single(self, shuffle):
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = UnrepeatedSortedSampler(data, length=data.data)
@@ -51,8 +52,7 @@ class TestUnrepeatedSortedSampler:


@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
@pytest.mark.parametrize('shuffle', [False, True])
def test_multi(self, num_replica, num_of_data, shuffle):
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
@@ -60,5 +60,45 @@ class TestUnrepeatedSortedSampler:
sampler.set_distributed(num_replica, rank=i)
samplers.append(sampler)


indexes = set(chain(*samplers))
# Make sure the order is not scrambled
for sampler in samplers:
prev_index = float('inf')
for index in sampler:
assert index <= prev_index
prev_index = index

indexes = list(chain(*samplers))
assert len(indexes) == num_of_data # no overlap between different ranks
indexes = set(indexes)
assert indexes==set(range(num_of_data))


class TestUnrepeatedSequentialSampler:
def test_single(self):
num_of_data = 100
data = DatasetWithVaryLength(num_of_data)
sampler = UnrepeatedSequentialSampler(data, length=data.data)
indexes = list(sampler)
assert indexes==list(range(num_of_data))

@pytest.mark.parametrize('num_replica', [2, 3])
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
def test_multi(self, num_replica, num_of_data):
data = DatasetWithVaryLength(num_of_data=num_of_data)
samplers = []
for i in range(num_replica):
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
sampler.set_distributed(num_replica, rank=i)
samplers.append(sampler)

# Make sure the order is not scrambled
for sampler in samplers:
prev_index = float('-inf')
for index in sampler:
assert index>=prev_index
prev_index = index

indexes = list(chain(*samplers))
assert len(indexes) == num_of_data
indexes = set(indexes)
assert indexes == set(range(num_of_data))
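
Finally, the property that distinguishes the unrepeated samplers tested in this file: under distributed evaluation the ranks partition the dataset exactly, with no padded duplicates. A minimal sketch, illustrative only and using just the API shown above, with a plain list standing in for DatasetWithVaryLength:

from itertools import chain
from fastNLP.core.samplers import UnrepeatedRandomSampler

dataset = list(range(10))
shards = []
for rank in range(3):
    s = UnrepeatedRandomSampler(dataset, shuffle=False)
    s.set_distributed(3, rank=rank)
    shards.append(list(s))

indices = list(chain(*shards))
assert len(indices) == len(dataset)               # no index appears twice
assert set(indices) == set(range(len(dataset)))   # and none is dropped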
