Browse Source

1、为TorchDataLoader添加get_batch_indices函数 2、在设置sampler后将shuffle设置为False

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
aff84e5955
1 changed files with 10 additions and 1 deletions
  1. +10
    -1
      fastNLP/core/dataloaders/torch_dataloader/fdl.py

+ 10
- 1
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -3,7 +3,7 @@ __all__ = [
'prepare_torch_dataloader' 'prepare_torch_dataloader'
] ]


from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator from fastNLP.core.collators import Collator
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):


if sampler is None and batch_sampler is None: if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle) sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False


super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


def get_batch_indices(self) -> List[int]:
"""
获取当前 batch 的 idx

:return:
"""
return self.cur_batch_indices



def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1, batch_size: int = 1,


Loading…
Cancel
Save