@@ -12,7 +12,7 @@ if _NEED_IMPORT_JITTOR:
    from jittor.dataset import Dataset
else:
    from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet
@@ -63,13 +63,6 @@ class PaddleDataLoader(DataLoader):
            shuffle = False
            drop_last = False
        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
                                               return_list=return_list, batch_sampler=batch_sampler,
                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                               collate_fn=None, num_workers=num_workers,
                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
                                               timeout=timeout, worker_init_fn=worker_init_fn,
                                               persistent_workers=persistent_workers)
        if isinstance(collate_fn, str):
            if collate_fn == 'auto':
                if isinstance(dataset.dataset, FDataSet):
@@ -80,11 +73,15 @@ class PaddleDataLoader(DataLoader):
            else:
                raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
        elif isinstance(collate_fn, Callable):
            if collate_fn is not default_collate_fn:
                self._collate_fn = collate_fn
        else:
            self._collate_fn = default_collate_fn
        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
                                               return_list=return_list, batch_sampler=batch_sampler,
                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                               collate_fn=collate_fn, num_workers=num_workers,
                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
                                               timeout=timeout, worker_init_fn=worker_init_fn,
                                               persistent_workers=persistent_workers)
        # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
        # if collate_fn is not None:
        #     _collate_fn.add_collator(collate_fn)
@@ -96,13 +93,13 @@ class PaddleDataLoader(DataLoader):
        # if len(self._collate_fn.get_collators()) == 0:
        #     self._collate_fn.add_collator(default_collate_fn)
        # self._collate_fn = default_collate_fn
        self.collate_fn = indice_collate_wrapper(self._collate_fn)
        self.collate_fn = indice_collate_wrapper(self.collate_fn)
        for indices, data in super().__iter__():
            self.cur_batch_indices = indices
            yield data
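
# Hedged, dependency-free sketch (not part of the diff) of the iteration
# protocol above: the parent iterator yields (indices, data) pairs, and
# __iter__ stores the indices so callers can ask which dataset rows formed
# the current batch.
class _ToyLoader:
    def __init__(self, batches):
        self._batches = batches               # [(indices, data), ...]
        self.cur_batch_indices = None

    def __iter__(self):
        for indices, data in self._batches:
            self.cur_batch_indices = indices  # mirrors PaddleDataLoader.__iter__
            yield data

dl = _ToyLoader([([0, 1], 'batch-0'), ([2, 3], 'batch-1')])
for data in dl:
    print(dl.cur_batch_indices, data)         # [0, 1] batch-0 / [2, 3] batch-1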

    def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
                pad_fn:Callable=None) -> Collator:
    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
                pad_fn: Callable = None) -> Collator:
        """
        Use this function if the contents of a particular field need special handling.
@@ -121,12 +118,26 @@ class PaddleDataLoader(DataLoader):
            form; the output will be used directly as the result.
        :return: the Collator itself
        """
        if isinstance(self._collate_fn, Collator):
            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
            return self._collate_fn
        collator = self._get_collator()
        if isinstance(collator, Collator):
            collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
            return collator
        else:
            raise ValueError("set_pad() is only allowed when the collate_fn is a fastNLP Collator.")

    def _get_collator(self):
        """
        If the collate_fn is a Collator object, return that object; otherwise return None.

        :return:
        """
        collator = None
        if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
            collator = self.collate_fn.__wrapped__
        elif isinstance(self.collate_fn, Collator):
            collator = self.collate_fn
        return collator
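
# Hedged sketch (an assumption about the wrapping layer's contract; the diff
# does not show where `__wrapped__` is set): _get_collator checks that
# attribute because the index-wrapping function is expected to keep a
# reference to the original collator there.
class _FakeCollator:
    """Stand-in for fastNLP's Collator in this sketch."""

def _wrap(fn):
    def inner(batch):
        return fn(batch)
    inner.__wrapped__ = fn      # the attribute _get_collator inspects
    return inner

wrapped = _wrap(_FakeCollator())
assert isinstance(getattr(wrapped, '__wrapped__', None), _FakeCollator)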

    def set_ignore(self, *field_names) -> Collator:
        """
        If some content should not be output, set it here; the specified fields will be ignored in the batch output.
@@ -138,9 +149,10 @@ class PaddleDataLoader(DataLoader):
            If __getitem__ returns a Sequence, '_0' and '_1' can be used to refer to the 0th or 1st element of the sequence.
        :return: the Collator itself
        """
        if isinstance(self._collate_fn, Collator):
            self._collate_fn.set_ignore(*field_names)
            return self._collate_fn
        collator = self._get_collator()
        if isinstance(collator, Collator):
            collator.set_ignore(*field_names)
            return collator
        else:
            raise ValueError("set_ignore() is only allowed when the collate_fn is a fastNLP Collator.")
@@ -163,6 +175,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                              worker_init_fn: Callable = None, persistent_workers=False,
                              non_train_batch_size: int = 16) \
        -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
    from fastNLP.io.data_bundle import DataBundle
    if isinstance(ds_or_db, Dataset):
        dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
                              batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
@@ -170,6 +183,30 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                              use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                              timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
        return dl
    elif isinstance(ds_or_db, DataBundle):
        dl_bundle = {}
        for name, ds in ds_or_db.iter_datasets():
            if 'train' in name:
                dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
                                                   return_list=return_list,
                                                   batch_sampler=batch_sampler, batch_size=train_batch_size,
                                                   shuffle=shuffle,
                                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                                   use_shared_memory=use_shared_memory,
                                                   use_buffer_reader=use_buffer_reader,
                                                   timeout=timeout, worker_init_fn=worker_init_fn,
                                                   persistent_workers=persistent_workers)
            else:
                dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
                                                   return_list=return_list,
                                                   batch_sampler=batch_sampler, batch_size=non_train_batch_size,
                                                   shuffle=shuffle,
                                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                                   use_shared_memory=use_shared_memory,
                                                   use_buffer_reader=use_buffer_reader,
                                                   timeout=timeout, worker_init_fn=worker_init_fn,
                                                   persistent_workers=persistent_workers)
        return dl_bundle
    elif isinstance(ds_or_db, Sequence):
        ds_seq = []
        for ds in ds_or_db:
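
# Hedged usage sketch for the DataBundle branch above (the DataBundle
# construction and split names are assumptions; running this requires paddle):
from fastNLP.core.dataset import DataSet as FDataSet
from fastNLP.io.data_bundle import DataBundle

train_ds = FDataSet({'x': [[1, 2], [3, 4, 5]] * 8, 'y': [0, 1] * 8})
dev_ds = FDataSet({'x': [[1, 2, 3]] * 4, 'y': [1] * 4})
bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds})

dls = prepare_paddle_dataloader(bundle, train_batch_size=4, non_train_batch_size=2)
# every split whose name contains 'train' gets train_batch_size; the rest get
# non_train_batch_size; the result maps split name -> PaddleDataLoader
assert set(dls) == {'train', 'dev'}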
@@ -14,7 +14,6 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler,
if _NEED_IMPORT_TORCH:
    from torch.utils.data import DataLoader, Sampler
    from torch.utils.data._utils.collate import default_collate
else:
    from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
@@ -73,12 +72,15 @@ class TorchDataLoader(DataLoader):
        :param prefetch_factor:
        :param persistent_workers:
        """
        if isinstance(dataset, DataSet) and collate_fn is None:
            raise ValueError("When using a fastNLP DataSet, collate_fn must not be None.")
        if not isinstance(dataset, _FDataSet):
            dataset = _FDataSet(dataset)

        if sampler is None and batch_sampler is None:
            sampler = RandomSampler(dataset, shuffle=shuffle)
            shuffle=False
            shuffle = False
        if isinstance(collate_fn, str):
            if collate_fn == 'auto':
                if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is used
@@ -107,8 +109,8 @@ class TorchDataLoader(DataLoader):
            self.cur_batch_indices = indices
            yield data

    def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
                pad_fn:Callable=None) -> Collator:
    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
                pad_fn: Callable = None) -> Collator:
        """
        Use this function if the contents of a particular field need special handling.
@@ -176,9 +178,10 @@ class TorchDataLoader(DataLoader):
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                             batch_size: int = 16,
                             shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                             shuffle: bool = True,
                             sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                             batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                             num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
                             num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',
                             pin_memory: bool = False, drop_last: bool = False,
                             timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                             multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
@@ -207,6 +210,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
    :param non_train_sampler: the Sampler used for non-'train' data, and for the second and later datasets when a Sequence is passed
    :param non_train_batch_size:
    """
    from fastNLP.io import DataBundle

    if isinstance(ds_or_db, DataSet):
        dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
@@ -218,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
                             )
        return dl
    elif type(ds_or_db, DataBundle):
    elif isinstance(ds_or_db, DataBundle):
        dl_bundle = {}
        for name, ds in ds_or_db.iter_datasets():
            if 'train' in name:
@@ -1,3 +1,8 @@
__all__ = [
    "indice_collate_wrapper"
]


def indice_collate_wrapper(func):
    """
    Wrap collate_fn in an extra layer that splits the (idx, instance) tuples fetched from the dataset and packs the idx values into indices.
@@ -5,7 +10,7 @@ def indice_collate_wrapper(func):
    :param func: the function to wrap
    :return:
    """
    if func.__name__ == '_indice_collate_wrapper':  # if it has already been wrapped
    if hasattr(func, '__name__') and func.__name__ == '_indice_collate_wrapper':  # if it has already been wrapped
        return func

    def _indice_collate_wrapper(tuple_data):  # functools.wraps cannot be used here, otherwise the wrapping cannot be detected
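
# Hedged sketch of the wrapper body that the hunk above truncates (inferred
# from the docstring's contract, not copied from the source): split the
# (idx, instance) pairs, collate the instances, and return the indices
# alongside the collated batch.
def _sketch_indice_collate_wrapper(func):
    def _indice_collate_wrapper(tuple_data):
        indices, ins_list = [], []
        for idx, ins in tuple_data:
            indices.append(idx)
            ins_list.append(ins)
        return indices, func(ins_list)
    return _indice_collate_wrapper

indices, batch = _sketch_indice_collate_wrapper(list)([(0, 'a'), (1, 'b')])
assert indices == [0, 1] and batch == ['a', 'b']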
@@ -15,16 +15,16 @@ import numpy as np
from threading import Thread

try:
    import multiprocess as mp
    from multiprocess import RLock
    import multiprocessing as mp
except:
    pass

from .field import FieldArray
from .instance import Instance
from fastNLP.core.utils.utils import pretty_table_printer, deprecated
from fastNLP.core.utils.utils import pretty_table_printer
from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.log import logger


class ApplyResultException(Exception):
@@ -67,7 +67,7 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
    except BaseException as e:
        if idx != -1:
            print("Exception happens at the `{}`th instance.".format(idx))
            logger.error("Exception happens at the `{}`th instance.".format(idx))
        raise e
    finally:
        if show_progress_bar:
@@ -98,6 +98,7 @@ def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: b
        if nums == total_len:
            break
    f_rich_progress.destroy_task(main_pro)
    # pb_main.close()
@@ -519,7 +520,7 @@ class DataSet:
            # start the process pool and the progress-bar thread
            main_thread.start()
            pool = mp.Pool(processes=num_proc)
            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id})
            pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                         for proc_id, ds in enumerate(shard_data)]
            pool.close()
            pool.join()
@@ -775,3 +776,35 @@ class DataSet:
        if self._collator is None:
            self._collator = Collator()
        return self._collator


if __name__ == '__main__':
    # from fastNLP import DataSet
    # if __name__=='__main__':
    #     data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
    #     data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)

    import multiprocess as mp
    # from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
    from functools import partial
    from threading import Thread

    shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
                  DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
    parent, child = mp.Pipe()
    partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
                                 pipe=child, show_progress_bar=False)
    thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
    thread.start()
    pool = mp.Pool(processes=6)
    pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
                 for proc_id, ds in enumerate(shard_data)]
    pool.close()
    pool.join()
    thread.join()
    results = []
    for async_result in pool_outs:
        data = async_result.get()
        results.extend(data)
    print(results)
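    # Hedged note: with the two 200-instance shards above, `results` should be
    # a flat list of 400 lengths (each 'x' string has 18 characters).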
@@ -3,6 +3,10 @@ import pytest
from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from fastNLP.core.dataset import DataSet
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch


@pytest.mark.torch
@@ -22,20 +26,16 @@ class TestFdl:
        fdl = TorchDataLoader(ds, batch_size=3)
        fdl.set_pad("x", -1)
        for batch in fdl:
            print(batch)
        # fdl.set_pad_val("x", val=-2)
        # for batch in fdl:
        #     print(batch)
            assert batch['x'].shape == torch.Size([3, 4])

    def test_get_batch_indices(self):
        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
        fdl = TorchDataLoader(ds, batch_size=3, shuffle=True)
        for batch in fdl:
            print(fdl.get_batch_indices())
            assert len(fdl.get_batch_indices()) == 3

    def test_other_dataset(self):
        import numpy as np

        class _DataSet:
            def __init__(self):
@@ -55,8 +55,41 @@ class TestFdl:
        # dl.set_inputs('data', 'labels')
        # dl.set_pad_val('labels', val=None)
        for batch in dl:
            print(batch)
            print(dl.get_batch_indices())
            assert batch[0].shape == torch.Size([2, 5])
            assert batch[1].shape == torch.Size([2, 2, 3])

    def test_default_collate_fn(self):
        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
        with pytest.raises(ValueError):
            fdl = TorchDataLoader(ds, batch_size=3, collate_fn=None)
        import numpy as np

        class _DataSet:
            def __init__(self):
                pass

            def __getitem__(self, item):
                return np.random.randn(5), [[1, 2], [2, 3, 4]]

            def __len__(self):
                return 10

        fdl = TorchDataLoader(_DataSet(), batch_size=3, collate_fn=None, drop_last=True)
        for batch in fdl:
            assert batch[0].shape == torch.Size([3, 5])

    def test_my_collate_fn(self):
        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})

        def collate_fn(batch):
            res = {'x': [], 'y': []}
            for ins in batch:
                res['x'].append(ins['x'])
                res['y'].append(ins['y'])
            return res

        fdl = TorchDataLoader(ds, collate_fn=collate_fn, batch_size=3, drop_last=True)
        for batch in fdl:
            assert batch['x'] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
            assert batch['y'] == [1, 0, 1]
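
    # Hedged companion test in the same style as those above (an assumption,
    # not part of the diff): set_ignore should drop a field from batch output.
    def test_set_ignore_sketch(self):
        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
        fdl = TorchDataLoader(ds, batch_size=3)
        fdl.set_ignore('y')            # 'y' should no longer appear in batches
        for batch in fdl:
            assert 'y' not in batch
            assert 'x' in batch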

    def test_prepare_torch_dataloader(self):
        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
@@ -184,7 +184,7 @@ class TestDataSetMethods:
        ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
        assert ds.field_arrays["y"].content[0] == 2

        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
        assert (isinstance(res, list) and len(res) > 0) == True
        assert res[0] == 4
@@ -375,6 +375,10 @@ class TestDataSetMethods:
        ds.add_seq_len('x')
        print(ds)

    def test_apply_proc(self):
        data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)


class TestFieldArrayInit:
    """