|
|
@@ -63,13 +63,6 @@ class PaddleDataLoader(DataLoader): |
|
|
|
shuffle = False |
|
|
|
drop_last = False |
|
|
|
|
|
|
|
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, |
|
|
|
return_list=return_list, batch_sampler=batch_sampler, |
|
|
|
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, |
|
|
|
collate_fn=None, num_workers=num_workers, |
|
|
|
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, |
|
|
|
timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
|
persistent_workers=persistent_workers) |
|
|
|
if isinstance(collate_fn, str): |
|
|
|
if collate_fn == 'auto': |
|
|
|
if isinstance(dataset.dataset, FDataSet): |
|
|
@@ -80,11 +73,15 @@ class PaddleDataLoader(DataLoader): |
|
|
|
|
|
|
|
else: |
|
|
|
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") |
|
|
|
elif isinstance(collate_fn, Callable): |
|
|
|
if collate_fn is not default_collate_fn: |
|
|
|
self._collate_fn = collate_fn |
|
|
|
else: |
|
|
|
self._collate_fn = default_collate_fn |
|
|
|
|
|
|
|
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, |
|
|
|
return_list=return_list, batch_sampler=batch_sampler, |
|
|
|
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, |
|
|
|
collate_fn=collate_fn, num_workers=num_workers, |
|
|
|
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, |
|
|
|
timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
|
persistent_workers=persistent_workers) |
|
|
|
|
|
|
|
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) |
|
|
|
# if collate_fn is not None: |
|
|
|
# _collate_fn.add_collator(collate_fn) |
|
|
@@ -96,13 +93,13 @@ class PaddleDataLoader(DataLoader): |
|
|
|
# if len(self._collate_fn.get_collators()) == 0: |
|
|
|
# self._collate_fn.add_collator(default_collate_fn) |
|
|
|
# self._collate_fn = default_collate_fn |
|
|
|
self.collate_fn = indice_collate_wrapper(self._collate_fn) |
|
|
|
self.collate_fn = indice_collate_wrapper(self.collate_fn) |
|
|
|
for indices, data in super().__iter__(): |
|
|
|
self.cur_batch_indices = indices |
|
|
|
yield data |
|
|
|
|
|
|
|
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, |
|
|
|
pad_fn:Callable=None) -> Collator: |
|
|
|
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, |
|
|
|
pad_fn: Callable = None) -> Collator: |
|
|
|
""" |
|
|
|
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 |
|
|
|
|
|
|
@@ -121,12 +118,26 @@ class PaddleDataLoader(DataLoader): |
|
|
|
形式,输出将被直接作为结果输出。 |
|
|
|
:return: 返回 Collator 自身 |
|
|
|
""" |
|
|
|
if isinstance(self._collate_fn, Collator): |
|
|
|
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) |
|
|
|
return self._collate_fn |
|
|
|
collator = self._get_collator() |
|
|
|
if isinstance(collator, Collator): |
|
|
|
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) |
|
|
|
return collator |
|
|
|
else: |
|
|
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") |
|
|
|
|
|
|
|
def _get_collator(self): |
|
|
|
""" |
|
|
|
如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
collator = None |
|
|
|
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): |
|
|
|
collator = self.collate_fn.__wrapped__ |
|
|
|
elif isinstance(self.collate_fn, Collator): |
|
|
|
collator = self.collate_fn |
|
|
|
return collator |
|
|
|
|
|
|
|
def set_ignore(self, *field_names) -> Collator: |
|
|
|
""" |
|
|
|
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 |
|
|
@@ -138,9 +149,10 @@ class PaddleDataLoader(DataLoader): |
|
|
|
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 |
|
|
|
:return: 返回 Collator 自身 |
|
|
|
""" |
|
|
|
if isinstance(self._collate_fn, Collator): |
|
|
|
self._collate_fn.set_ignore(*field_names) |
|
|
|
return self._collate_fn |
|
|
|
collator = self._get_collator() |
|
|
|
if isinstance(collator, Collator): |
|
|
|
collator.set_ignore(*field_names) |
|
|
|
return collator |
|
|
|
else: |
|
|
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") |
|
|
|
|
|
|
@@ -163,6 +175,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, |
|
|
|
worker_init_fn: Callable = None, persistent_workers=False, |
|
|
|
non_train_batch_size: int = 16) \ |
|
|
|
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: |
|
|
|
from fastNLP.io.data_bundle import DataBundle |
|
|
|
if isinstance(ds_or_db, Dataset): |
|
|
|
dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, |
|
|
|
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, |
|
|
@@ -170,6 +183,30 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, |
|
|
|
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, |
|
|
|
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) |
|
|
|
return dl |
|
|
|
elif isinstance(ds_or_db, DataBundle): |
|
|
|
dl_bundle = {} |
|
|
|
for name, ds in ds_or_db.iter_datasets(): |
|
|
|
if 'train' in name: |
|
|
|
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, |
|
|
|
return_list=return_list, |
|
|
|
batch_sampler=batch_sampler, batch_size=train_batch_size, |
|
|
|
shuffle=shuffle, |
|
|
|
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, |
|
|
|
use_shared_memory=use_shared_memory, |
|
|
|
use_buffer_reader=use_buffer_reader, |
|
|
|
timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
|
persistent_workers=persistent_workers) |
|
|
|
else: |
|
|
|
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, |
|
|
|
return_list=return_list, |
|
|
|
batch_sampler=batch_sampler, batch_size=non_train_batch_size, |
|
|
|
shuffle=shuffle, |
|
|
|
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, |
|
|
|
use_shared_memory=use_shared_memory, |
|
|
|
use_buffer_reader=use_buffer_reader, |
|
|
|
timeout=timeout, worker_init_fn=worker_init_fn, |
|
|
|
persistent_workers=persistent_workers) |
|
|
|
return dl_bundle |
|
|
|
elif isinstance(ds_or_db, Sequence): |
|
|
|
ds_seq = [] |
|
|
|
for ds in ds_or_db: |
|
|
|