@@ -0,0 +1,14 @@ | |||||
__all__ = [ | |||||
'MixDataLoader', | |||||
'TorchDataLoader', | |||||
'PaddleDataLoader', | |||||
'JittorDataLoader', | |||||
'prepare_jittor_dataloader', | |||||
'prepare_paddle_dataloader', | |||||
'prepare_torch_dataloader' | |||||
] | |||||
from .mix_dataloader import MixDataLoader | |||||
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | |||||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | |||||
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader |
@@ -0,0 +1,7 @@ | |||||
__all__ = [ | |||||
'FDataLoader' | |||||
] | |||||
class FDataLoader: | |||||
pass |
@@ -0,0 +1,7 @@ | |||||
__all__ = [ | |||||
"JittorDataLoader", | |||||
'prepare_jittor_dataloader' | |||||
] | |||||
from .fdl import JittorDataLoader, prepare_jittor_dataloader |
@@ -0,0 +1,138 @@ | |||||
__all__ = [ | |||||
'JittorDataLoader', | |||||
'prepare_jittor_dataloader' | |||||
] | |||||
from typing import Callable, Optional, List | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
from jittor.dataset.utils import collate_batch | |||||
from jittor.dataset import Dataset | |||||
else: | |||||
from fastNLP.core.dataset import DataSet as Dataset | |||||
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps | |||||
from fastNLP.core.collators import AutoCollator | |||||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataset import DataSet as FDataSet | |||||
class _JittorDataset(Dataset): | |||||
""" | |||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset使用jittor的dataset | |||||
""" | |||||
def __init__(self, dataset) -> None: | |||||
super(_JittorDataset, self).__init__() | |||||
self.dataset = dataset | |||||
def __getitem__(self, item): | |||||
return (item, self.dataset[item]) | |||||
def __len__(self) -> int: | |||||
return len(self.dataset) | |||||
    # def __getattr__(self, item):
    #     # jittor's Dataset lacks this method, but if the user's dataset implements it (via __getattribute__), forward the call so the user can still use it
    #     try:
    #         self.dataset.__getattribute__(item)
    #     except Exception as e:
    #         raise e
class JittorDataLoader: | |||||
""" | |||||
提供给使用jittor框架的DataLoader函数,提供了auto_collate的功能, 支持实现了__getitem__和__len__的dataset | |||||
""" | |||||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, | |||||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | |||||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | |||||
collate_fn: Callable = None) -> None: | |||||
""" | |||||
:param dataset: 实现__getitem__和__len__的dataset | |||||
:param batch_size: 批次大小 | |||||
:param shuffle: 是否打乱数据集 | |||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||||
:param num_workers: 进程的数量,当num_workers=0时不开启多进程 | |||||
:param buffer_size: | |||||
:param stop_grad: | |||||
:param keep_numpy_array: | |||||
:param endless: | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | |||||
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 | |||||
""" | |||||
        # TODO support fastNLP dataset
        # TODO verify support for replacesampler (to be done later)
        # check whether the dataset is a fastNLP DataSet; if so, reuse its auto collator
if isinstance(dataset, FDataSet): | |||||
collator = dataset.get_collator().set_as_numpy(as_numpy=True) | |||||
else: | |||||
collator = None | |||||
self.dataset = _JittorDataset(dataset) | |||||
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||||
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, | |||||
keep_numpy_array=keep_numpy_array, endless=endless) | |||||
if isinstance(self.dataset.dataset, Dataset): | |||||
self.dataset.dataset.set_attrs(batch_size=1) | |||||
        # if the user provides a collate_fn, it automatically replaces the collate_batch function provided by jittor
self.collate_fn = collate_fn | |||||
if self.collate_fn is None: | |||||
self.collate_fn = collate_batch | |||||
self.auto_collator = collator | |||||
self.cur_batch_indices = None | |||||
def __iter__(self): | |||||
        # TODO collate_fn cannot be changed after the first iteration; setting it later has no effect
if self.cur_batch_indices is None: | |||||
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn, | |||||
self.auto_collator))) | |||||
for indices, data in self.dataset.__iter__(): | |||||
self.cur_batch_indices = indices | |||||
yield data | |||||
def __len__(self): | |||||
if self.dataset.drop_last: | |||||
return len(self.dataset) // self.dataset.batch_size | |||||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | |||||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||||
""" | |||||
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 | |||||
当val=None时,意味着给定的field_names都不需要尝试padding | |||||
:param field_names: | |||||
:param val: padding值,默认为0 | |||||
:return: | |||||
""" | |||||
if self.auto_collator is None: | |||||
self.auto_collator = AutoCollator(as_numpy=True) | |||||
self.auto_collator.set_pad_val(*field_names, val=val) | |||||
def set_input(self, *field_names) -> None: | |||||
""" | |||||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||||
:param field_names: | |||||
:return: | |||||
""" | |||||
if self.auto_collator is None: | |||||
self.auto_collator = AutoCollator(as_numpy=True) | |||||
self.auto_collator.set_input(*field_names) | |||||
def get_batch_indices(self) -> List[int]: | |||||
""" | |||||
获取当前数据的idx | |||||
:return: | |||||
""" | |||||
return self.cur_batch_indices | |||||
def prepare_jittor_dataloader(): | |||||
... |
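# A minimal usage sketch (not part of the module): it assumes a fastNLP DataSet with hypothetical
# fields 'x' and 'y' and shows how the auto-collate helpers above are meant to be combined;
# the field names and values are illustrative only.
#
#     from fastNLP.core.dataset import DataSet
#     ds = DataSet({'x': [[1, 2], [2, 3, 4], [4, 5, 6, 7]], 'y': [1, 0, 1]})
#     dl = JittorDataLoader(ds, batch_size=2, shuffle=False)
#     dl.set_input('x', 'y')       # only these fields are collated
#     dl.set_pad_val('x', val=0)   # pad 'x' with zeros to the longest sample in each batch
#     for batch in dl:
#         print(dl.get_batch_indices(), batch)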
@@ -0,0 +1,194 @@ | |||||
__all__ = [ | |||||
'MixDataLoader' | |||||
] | |||||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence | |||||
import numpy as np | |||||
from fastNLP.core.dataset import DataSet, Instance | |||||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader, Sampler | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||||
class _MixDataset: | |||||
""" | |||||
将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx | |||||
""" | |||||
def __init__(self, datasets: list = None) -> None: | |||||
""" | |||||
:param datasets: 数据集的列表 | |||||
""" | |||||
self.datasets = datasets | |||||
        # record the cumulative lengths of the datasets so that an idx can be mapped back to its dataset
self.lens = [] | |||||
index = 0 | |||||
for item in self.datasets: | |||||
index += len(item) | |||||
self.lens.append(index) | |||||
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | |||||
""" | |||||
:param idx: | |||||
:return: | |||||
""" | |||||
if isinstance(idx, int): | |||||
if idx >= self.lens[-1]: | |||||
raise ValueError(f"idx: {idx} out of range") | |||||
            # find which dataset the idx belongs to and return its position
ds_index = np.searchsorted(self.lens, idx, side='right') | |||||
if ds_index > 0: | |||||
idx -= self.lens[ds_index - 1] | |||||
return self.datasets[ds_index][idx], ds_index | |||||
elif isinstance(idx, list): | |||||
            # a list of indices is normally expected to come from a single dataset; otherwise an error is raised
dataset = DataSet() | |||||
ds_index = 0 | |||||
for i in idx: | |||||
assert isinstance(i, int), "Only int index allowed." | |||||
instance, ds_index = self[i] | |||||
dataset.append(instance) | |||||
return dataset, ds_index | |||||
else: | |||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||||
def __len__(self) -> int: | |||||
return self.lens[-1] | |||||
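# A small worked example of the cumulative-length lookup above (illustrative numbers only): with two
# datasets of lengths 3 and 5, self.lens == [3, 8]; for a global idx of 4,
# np.searchsorted([3, 8], 4, side='right') == 1, so the item is datasets[1][4 - 3], i.e. datasets[1][1].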
class _MixCollateFn: | |||||
""" | |||||
存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 | |||||
""" | |||||
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, | |||||
auto_collators: Optional[List[Callable]] = None) -> None: | |||||
if isinstance(collate_fns, Sequence): | |||||
self.collate_fns = lambda idx, lst: collate_fns[idx](lst) | |||||
elif callable(collate_fns): | |||||
self.collate_fns = lambda idx, lst: collate_fns(lst) | |||||
else: | |||||
            self.collate_fns = lambda idx, lst: lst
self.auto_collators = auto_collators | |||||
def __call__(self, ins_list: List) -> Dict: | |||||
""" | |||||
调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 | |||||
:param ins_list: | |||||
:return: | |||||
""" | |||||
_ins_list, _ds_index = [], 0 | |||||
for ins, _ds_index in ins_list: | |||||
_ins_list.append(ins) | |||||
        # apply the auto_collator first
if self.auto_collators is not None: | |||||
_ins_list = self.auto_collators[_ds_index](_ins_list) | |||||
_ins_list = self.collate_fns(_ds_index, _ins_list) | |||||
return _ins_list | |||||
class MixDataLoader(DataLoader): | |||||
""" | |||||
针对一下三种情况提供的MixDataLoader: | |||||
1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。 | |||||
2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。 | |||||
3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset | |||||
采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。 | |||||
""" | |||||
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential', | |||||
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None, | |||||
num_workers: int = 0, batch_size: int = 16, drop_last=False, | |||||
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None, | |||||
pin_memory: bool = True) -> None: | |||||
""" | |||||
:param datasets: dataset的列表 | |||||
:param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况, | |||||
当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代 | |||||
对象,每次迭代返回的是List[int] | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数, | |||||
当其为callable类型时候,所有数据集采样的数据都会经过这个函数; | |||||
当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数, | |||||
其对应关系与datasets位置匹配; | |||||
当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。 | |||||
:param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象 | |||||
sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样; | |||||
sampler为List[Sampler],则datasets也为List,且一一对应 | |||||
sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。 | |||||
:param num_workers: 进程的数量,当num_workers=0时不开启多进程 | |||||
:param batch_size: 批次大小, datasets的所有数据集batch_size一致 | |||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||||
:param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充 | |||||
当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度 | |||||
当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止 | |||||
当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数, | |||||
其大于0,可以超过1,将数据集重采样翻倍即可 | |||||
当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。 | |||||
""" | |||||
        # if datasets is a Dict, the other parameters such as collate_fn must be a Dict or a Callable as well
        if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \
                isinstance(sampler, Dict):
            raise ValueError("when sampler is given as a Dict, datasets (and collate_fn) must be Dicts as well")
        if isinstance(collate_fn, list):
            if len(collate_fn) != len(datasets):
                raise ValueError("the length of collate_fn must match the number of datasets")
        if isinstance(sampler, list):
            if len(sampler) != len(datasets):
                raise ValueError("the length of sampler must match the number of datasets")
        # convert Dict to List so that _MixCollateFn can handle it
if isinstance(collate_fn, Dict): | |||||
collate_fn = [fn for _, fn in collate_fn.items()] | |||||
        # datasets may contain fastNLP DataSets or a mix of dataset types, so each one needs to be checked
if isinstance(datasets, Dict): | |||||
dataset = [ds for _, ds in datasets.items()] | |||||
else: | |||||
dataset = datasets | |||||
auto_collators = [] | |||||
for per_ds in dataset: | |||||
if isinstance(per_ds, DataSet): | |||||
auto_collators.append(per_ds.get_collator()) | |||||
else: | |||||
                # if there is no matching collator, register a no-op collator
auto_collators.append(lambda x: x) | |||||
        # at this point collate_fn is either a single callable or a list, so wrap it together with the auto_collators
collate_fn = _MixCollateFn(collate_fn, auto_collators) | |||||
if mode == 'sequential': | |||||
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, | |||||
drop_last=drop_last, ds_ratio=ds_ratio) | |||||
elif mode == 'polling': | |||||
batch_sampler = PollingSampler(datasets, batch_size=batch_size, sampler=sampler, | |||||
drop_last=drop_last, ds_ratio=ds_ratio) | |||||
elif mode == 'mix': | |||||
batch_sampler = DopedSampler(datasets, batch_size=batch_size, sampler=sampler, | |||||
drop_last=drop_last, ds_ratio=ds_ratio) | |||||
elif isinstance(mode, Sampler): | |||||
batch_sampler = mode | |||||
else: | |||||
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") | |||||
super(MixDataLoader, self).__init__( | |||||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||||
prefetch_factor=2, persistent_workers=False | |||||
) | |||||
def __iter__(self): | |||||
return super().__iter__() |
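# A minimal usage sketch (not part of the module): two hypothetical fastNLP DataSets are mixed and drawn
# batch-by-batch in polling mode; the field names, sizes and values are illustrative only.
#
#     from fastNLP.core.dataset import DataSet
#     ds_a = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})
#     ds_b = DataSet({'x': [[6], [7, 8], [9, 10, 11]], 'y': [1, 0, 1]})
#     dl = MixDataLoader(datasets={'a': ds_a, 'b': ds_b}, mode='polling', batch_size=2)
#     for batch in dl:
#         ...  # one batch from 'a', then one from 'b', and so on until the datasets are exhausted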
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
'prepare_paddle_dataloader', | |||||
'PaddleDataLoader' | |||||
] | |||||
from .fdl import PaddleDataLoader, prepare_paddle_dataloader |
@@ -0,0 +1,192 @@ | |||||
__all__ = [ | |||||
'PaddleDataLoader', | |||||
'prepare_paddle_dataloader' | |||||
] | |||||
from typing import Callable, List, Optional, Union, Dict, Sequence | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
if _NEED_IMPORT_PADDLE: | |||||
from paddle.io import DataLoader, Dataset | |||||
from paddle.fluid.dataloader.collate import default_collate_fn | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||||
from fastNLP.core.collators.collator import _MultiCollator | |||||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||||
from fastNLP.core.dataset import DataSet as FDataSet | |||||
class _PaddleDataset(Dataset): | |||||
""" | |||||
对用户传的dataset进行封装,以便Fdataloader能够支持使用自定义的dataset使用paddle的dataloader | |||||
""" | |||||
def __init__(self, dataset) -> None: | |||||
super(_PaddleDataset, self).__init__() | |||||
self.dataset = dataset | |||||
def __getitem__(self, item): | |||||
return (item, self.dataset[item]) | |||||
def __len__(self) -> int: | |||||
return len(self.dataset) | |||||
    def __getattr__(self, item):
        try:
            return self.dataset.__getattribute__(item)
        except Exception as e:
            raise e
class PaddleDataLoader(DataLoader): | |||||
def __init__(self, dataset, feed_list=None, places=None, | |||||
return_list: bool = True, batch_sampler=None, | |||||
batch_size: int = 1, shuffle: bool = False, | |||||
drop_last: bool = False, collate_fn: Callable = None, | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | |||||
use_shared_memory: bool = True, timeout: int = 0, | |||||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | |||||
if not isinstance(dataset, _PaddleDataset): | |||||
dataset = _PaddleDataset(dataset) | |||||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||||
return_list=return_list, batch_sampler=batch_sampler, | |||||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||||
collate_fn=None, num_workers=num_workers, | |||||
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | |||||
persistent_workers=persistent_workers) | |||||
if isinstance(dataset.dataset, FDataSet): | |||||
self._collate_fn = dataset.dataset.get_collator() | |||||
self._collate_fn.set_as_numpy(as_numpy=True) | |||||
if collate_fn is not None: | |||||
self._collate_fn.add_collator(collate_fn) | |||||
else: | |||||
self._collate_fn = _MultiCollator(collate_fn) | |||||
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) | |||||
# if collate_fn is not None: | |||||
# _collate_fn.add_collator(collate_fn) | |||||
# self._collate_fn = _collate_fn | |||||
self.cur_batch_indices = None | |||||
def __iter__(self): | |||||
        # if there is neither an auto_collator nor a user-defined collate_fn, fall back to the dataloader's default collate_fn to pack the data
if len(self._collate_fn.get_collators()) == 0: | |||||
self._collate_fn.add_collator(default_collate_fn) | |||||
# self._collate_fn = default_collate_fn | |||||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||||
for indices, data in super().__iter__(): | |||||
self.cur_batch_indices = indices | |||||
yield data | |||||
def __getattr__(self, item): | |||||
""" | |||||
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 | |||||
:param item: | |||||
:return: | |||||
""" | |||||
try: | |||||
return self.dataset.__getattr__(item) | |||||
except AttributeError as e: | |||||
raise e | |||||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||||
""" | |||||
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 | |||||
当val=None时,意味着给定的field_names都不需要尝试padding | |||||
:param field_names: | |||||
:param val: padding值,默认为0 | |||||
:return: | |||||
""" | |||||
for field_name in field_names: | |||||
self._collate_fn.set_pad_val(field_name, val=val) | |||||
def set_input(self, *field_names) -> None: | |||||
""" | |||||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||||
:param field_names: | |||||
:return: | |||||
""" | |||||
self._collate_fn.set_input(*field_names) | |||||
def set_collator(self, collator: Callable) -> None: | |||||
""" | |||||
设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate | |||||
:param collator: 用户自定义的Callable函数 | |||||
:return: | |||||
""" | |||||
self._collate_fn = _MultiCollator(collator) | |||||
def add_collator(self, collator) -> None: | |||||
""" | |||||
添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 | |||||
:param collator: | |||||
:return: | |||||
""" | |||||
self._collate_fn.add_collator(collator) | |||||
def get_batch_indices(self) -> List[int]: | |||||
""" | |||||
获取当前数据的idx | |||||
:return: | |||||
""" | |||||
return self.cur_batch_indices | |||||
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
return_list: bool = True, batch_sampler=None, | |||||
train_batch_size: int = 1, shuffle: bool = False, | |||||
drop_last: bool = False, collate_fn: Callable = None, | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | |||||
use_shared_memory: bool = True, timeout: int = 0, | |||||
worker_init_fn: Callable = None, persistent_workers=False, | |||||
non_train_batch_size: int = 16, | |||||
input_fields: Union[List[str], str] = None)\ | |||||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||||
if isinstance(input_fields, str): | |||||
input_fields = [input_fields] | |||||
if isinstance(ds_or_db, Dataset): | |||||
... | |||||
elif isinstance(ds_or_db, Sequence): | |||||
ds_seq = [] | |||||
for ds in ds_or_db: | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||||
dl.set_input(*input_fields) | |||||
ds_seq.append(dl) | |||||
return ds_seq | |||||
elif isinstance(ds_or_db, Dict): | |||||
ds_dict = {} | |||||
for name, ds in ds_or_db.items(): | |||||
if 'train' in name: | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||||
else: | |||||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||||
batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle, | |||||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||||
dl.set_input(*input_fields) | |||||
ds_dict[name] = dl | |||||
return ds_dict | |||||
else: | |||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") |
@@ -0,0 +1,6 @@ | |||||
__all__ = [ | |||||
"TorchDataLoader", | |||||
"prepare_torch_dataloader" | |||||
] | |||||
from .fdl import TorchDataLoader, prepare_torch_dataloader |
@@ -0,0 +1,300 @@ | |||||
__all__ = [ | |||||
'TorchDataLoader', | |||||
'prepare_torch_dataloader' | |||||
] | |||||
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.collators import AutoCollator | |||||
from fastNLP.core.collators.collator import _MultiCollator | |||||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader, Sampler | |||||
from torch.utils.data._utils.collate import default_collate | |||||
else: | |||||
from ..fdataloader import FDataLoader as DataLoader | |||||
class _FDataSet: | |||||
""" | |||||
对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset | |||||
中调用dataset的方法 | |||||
""" | |||||
def __init__(self, dataset) -> None: | |||||
self.dataset = dataset | |||||
def __getitem__(self, item: Union[int, list]) -> Tuple: | |||||
return (item, self.dataset[item]) | |||||
def __getattr__(self, item): | |||||
try: | |||||
return self.dataset.__getattribute__(item) | |||||
except AttributeError as e: | |||||
raise e | |||||
def __len__(self) -> int: | |||||
return len(self.dataset) | |||||
class TorchDataLoader(DataLoader): | |||||
""" | |||||
提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过 | |||||
提供的方法调节设置collate_fn的若干参数。 | |||||
""" | |||||
def __init__(self, dataset, batch_size: int = 1, | |||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||||
pin_memory: bool = False, drop_last: bool = False, | |||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||||
persistent_workers: bool = False, as_numpy: bool = False) -> None: | |||||
""" | |||||
:param dataset: 实现了__getitem__和__len__的数据容器 | |||||
:param batch_size: 批次大小,当batch_sampler为None生效 | |||||
:param shuffle: 是否打乱数据集 | |||||
:param sampler: sampler实例化对象 | |||||
:param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 | |||||
:param num_workers: 进程的数量,当num_worker=0时不开启多进程 | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | |||||
:param pin_memory: | |||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||||
:param timeout: | |||||
:param worker_init_fn: | |||||
:param multiprocessing_context: | |||||
:param generator: | |||||
:param prefetch_factor: | |||||
:param persistent_workers: | |||||
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 | |||||
""" | |||||
if not isinstance(dataset, _FDataSet): | |||||
dataset = _FDataSet(dataset) | |||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | |||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers) | |||||
        if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is being used
self._collate_fn = dataset.dataset.get_collator() | |||||
self._collate_fn.set_as_numpy(as_numpy) | |||||
            if collate_fn is not None and collate_fn is not default_collate:
                # avoid re-adding torch dataloader's default collate when DDP re-initializes the dataloader
self._collate_fn.add_collator(collate_fn) | |||||
else: | |||||
self._collate_fn = _MultiCollator(collate_fn) | |||||
        self.cur_batch_indices = None
self.as_numpy = as_numpy | |||||
def __getattr__(self, item): | |||||
""" | |||||
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 | |||||
:param item: | |||||
:return: | |||||
""" | |||||
try: | |||||
return self.dataset.__getattr__(item) | |||||
except AttributeError as e: | |||||
raise e | |||||
def __iter__(self): | |||||
        # if there is neither an auto_collator nor a user-defined collate_fn, fall back to the dataloader's default collate_fn to pack the data
if len(self._collate_fn.get_collators()) == 0: | |||||
self._collate_fn.add_collator(self.collate_fn) | |||||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||||
for indices, data in super().__iter__(): | |||||
self.cur_batch_indices = indices | |||||
yield data | |||||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||||
""" | |||||
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 | |||||
当val=None时,意味着给定的field_names都不需要尝试padding | |||||
:param field_names: | |||||
:param val: padding值,默认为0 | |||||
:return: | |||||
""" | |||||
flag = False | |||||
for collator in self._collate_fn.get_collators(): | |||||
if isinstance(collator, AutoCollator): | |||||
flag = True | |||||
break | |||||
if flag is False: | |||||
self._collate_fn.add_collator(AutoCollator(self.as_numpy)) | |||||
for field_name in field_names: | |||||
self._collate_fn.set_pad_val(field_name, val=val) | |||||
def set_input(self, *field_names) -> None: | |||||
""" | |||||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||||
:param field_names: | |||||
:return: | |||||
""" | |||||
flag = False | |||||
for collator in self._collate_fn.get_collators(): | |||||
if isinstance(collator, AutoCollator): | |||||
flag = True | |||||
break | |||||
if flag is False: | |||||
self._collate_fn.add_collator(AutoCollator(self.as_numpy)) | |||||
self._collate_fn.set_input(*field_names) | |||||
def set_collator(self, collator: Callable) -> None: | |||||
""" | |||||
设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate | |||||
:param collator: 用户自定义的Callable函数 | |||||
:return: | |||||
""" | |||||
self._collate_fn = _MultiCollator(collator) | |||||
def add_collator(self, collator) -> None: | |||||
""" | |||||
添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 | |||||
:param collator: | |||||
:return: | |||||
""" | |||||
self._collate_fn.add_collator(collator) | |||||
def get_batch_indices(self) -> List[int]: | |||||
""" | |||||
获取当前数据的idx | |||||
:return: | |||||
""" | |||||
return self.cur_batch_indices | |||||
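# A brief note on the two collator hooks above (the callables below are hypothetical, user-defined names):
# set_collator replaces every collate_fn that is currently registered, including the auto collator,
# while add_collator appends to the existing ones.
#
#     dl = TorchDataLoader(ds, batch_size=4)    # ds: any fastNLP DataSet
#     dl.add_collator(my_post_processing)       # runs after the auto collator
#     dl.set_collator(my_full_collate_fn)       # from now on only my_full_collate_fn is applied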
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | |||||
batch_size: int = 1, | |||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||||
pin_memory: bool = False, drop_last: bool = False, | |||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||||
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, | |||||
non_train_batch_size: int = 16, as_numpy: bool = False, | |||||
input_fields: Union[List, str] = None)\ | |||||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||||
""" | |||||
传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 | |||||
:param input_fields: | |||||
:param ds_or_db: dataset或者data_bundle | |||||
:param batch_size: 批次大小,当batch_sampler为None生效 | |||||
:param shuffle: 是否打乱数据集 | |||||
:param sampler: sampler实例化对象 | |||||
:param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 | |||||
:param num_workers: 进程的数量,当num_worker=0时不开启多进程 | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | |||||
:param pin_memory: | |||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||||
:param timeout: | |||||
:param worker_init_fn: | |||||
:param multiprocessing_context: | |||||
:param generator: | |||||
:param prefetch_factor: | |||||
:param persistent_workers: | |||||
:param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler | |||||
:param non_train_batch_size: | |||||
:param as_numpy: 返回数据是否设置为numpy类型,否则根据情况设置为 torch.tensor 类型。 | |||||
""" | |||||
    # TODO the dict and sequence cases still need to be handled here
if isinstance(input_fields, str): | |||||
input_fields = [input_fields] | |||||
if isinstance(ds_or_db, DataSet): | |||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
dl.set_input(*input_fields) | |||||
return dl | |||||
elif isinstance(ds_or_db, DataBundle): | |||||
dl_bundle = {} | |||||
for name, ds in ds_or_db.iter_datasets(): | |||||
if 'train' in name: | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
else: | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
dl_bundle[name].set_input(*input_fields) | |||||
return dl_bundle | |||||
elif isinstance(ds_or_db, Sequence): | |||||
dl_bundle = [] | |||||
for idx, ds in enumerate(ds_or_db): | |||||
if idx == 0: | |||||
dl_bundle.append( | |||||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
) | |||||
else: | |||||
dl_bundle.append( | |||||
                    TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
                                    shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
) | |||||
for dl in dl_bundle: | |||||
dl.set_input(*input_fields) | |||||
return dl_bundle | |||||
elif isinstance(ds_or_db, Mapping): | |||||
dl_bundle = {} | |||||
for name, ds in ds_or_db.items(): | |||||
if 'train' in name: | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
else: | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, | |||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||||
as_numpy=as_numpy) | |||||
dl_bundle[name].set_input(*input_fields) | |||||
return dl_bundle | |||||
else: | |||||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") |