Browse Source

1、为TorchDataLoader添加get_batch_indices函数 2、在设置sampler后将shuffle设置为False

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
aff84e5955
1 changed files with 10 additions and 1 deletions
  1. +10
    -1
      fastNLP/core/dataloaders/torch_dataloader/fdl.py

+ 10
- 1
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -3,7 +3,7 @@ __all__ = [
'prepare_torch_dataloader'
]

from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader):

if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader):
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

def get_batch_indices(self) -> List[int]:
    """
    Return the sample indices of the current batch.

    :return: the value of ``self.cur_batch_indices`` — presumably the list of
        dataset indices used to assemble the most recent batch. NOTE(review):
        the attribute is set elsewhere (not visible in this diff); confirm it
        is initialized before the first batch is drawn, otherwise this may
        return ``None`` rather than a ``List[int]``.
    """
    return self.cur_batch_indices


def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,


Loading…
Cancel
Save