|
|
@@ -3,7 +3,7 @@ __all__ = [ |
|
|
|
'prepare_torch_dataloader' |
|
|
|
] |
|
|
|
|
|
|
|
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping |
|
|
|
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List |
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.core.collators import Collator |
|
|
@@ -78,6 +78,7 @@ class TorchDataLoader(DataLoader): |
|
|
|
|
|
|
|
if sampler is None and batch_sampler is None: |
|
|
|
sampler = RandomSampler(dataset, shuffle=shuffle) |
|
|
|
shuffle=False |
|
|
|
|
|
|
|
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, |
|
|
|
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, |
|
|
@@ -154,6 +155,14 @@ class TorchDataLoader(DataLoader): |
|
|
|
else: |
|
|
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") |
|
|
|
|
|
|
|
def get_batch_indices(self) -> List[int]:
    """
    Return the dataset indices of the most recently produced batch.

    :return: list of sample indices belonging to the current batch
    """
    indices = self.cur_batch_indices
    return indices
|
|
|
|
|
|
|
|
|
|
|
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], |
|
|
|
batch_size: int = 1, |
|
|
|