@@ -0,0 +1,14 @@
__all__ = [
    'MixDataLoader',
    'TorchDataLoader',
    'PaddleDataLoader',
    'JittorDataLoader',
    'prepare_jittor_dataloader',
    'prepare_paddle_dataloader',
    'prepare_torch_dataloader'
]

from .mix_dataloader import MixDataLoader
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
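The names listed in __all__ above are the public surface of this package. A minimal import sketch, assuming the package is exposed as fastNLP.core.dataloaders (the exact path is an assumption, not confirmed by this diff):

# hypothetical usage of the exports above
from fastNLP.core.dataloaders import TorchDataLoader, prepare_torch_dataloader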
@@ -0,0 +1,7 @@
__all__ = [
    'FDataLoader'
]


class FDataLoader:
    pass
@@ -0,0 +1,7 @@
__all__ = [
    'JittorDataLoader',
    'prepare_jittor_dataloader'
]

from .fdl import JittorDataLoader, prepare_jittor_dataloader
@@ -0,0 +1,138 @@
__all__ = [
    'JittorDataLoader',
    'prepare_jittor_dataloader'
]

from typing import Callable, Optional, List

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR

if _NEED_IMPORT_JITTOR:
    from jittor.dataset.utils import collate_batch
    from jittor.dataset import Dataset
else:
    from fastNLP.core.dataset import DataSet as Dataset

from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import AutoCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet


class _JittorDataset(Dataset):
    """
    Wraps the dataset passed in by the user so that JittorDataLoader can drive a
    custom dataset through jittor's own Dataset machinery.
    """

    def __init__(self, dataset) -> None:
        super(_JittorDataset, self).__init__()
        self.dataset = dataset

    def __getitem__(self, item):
        # return the index together with the sample so that the collate wrapper
        # can recover which indices make up the current batch
        return (item, self.dataset[item])

    def __len__(self) -> int:
        return len(self.dataset)

    # def __getattr__(self, item):
    #     # for methods that jittor's Dataset lacks but the user's dataset implements
    #     # (via __getattribute__), forward the lookup to the user's dataset
    #     try:
    #         self.dataset.__getattribute__(item)
    #     except Exception as e:
    #         raise e


class JittorDataLoader:
    """
    DataLoader for the jittor framework. Provides auto-collate support and accepts
    any dataset that implements __getitem__ and __len__.
    """

    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
                 drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                 stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                 collate_fn: Callable = None) -> None:
        """
        :param dataset: a dataset implementing __getitem__ and __len__
        :param batch_size: batch size
        :param shuffle: whether to shuffle the dataset
        :param drop_last: whether to drop the last batch when it is smaller than batch_size
        :param num_workers: number of worker processes; 0 disables multiprocessing
        :param buffer_size:
        :param stop_grad:
        :param keep_numpy_array:
        :param endless:
        :param collate_fn: callable used to collate the fetched samples into a batch
        """
        # TODO support fastNLP dataset
        # TODO verify support for ReplaceSampler (to be done later)
        # check whether this is a fastNLP dataset
        if isinstance(dataset, FDataSet):
            collator = dataset.get_collator()
            collator.set_as_numpy(as_numpy=True)
        else:
            collator = None

        self.dataset = _JittorDataset(dataset)

        self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                               num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
                               keep_numpy_array=keep_numpy_array, endless=endless)
        if isinstance(self.dataset.dataset, Dataset):
            self.dataset.dataset.set_attrs(batch_size=1)
        # if the user provided a collate_fn, it replaces jittor's collate_batch
        self.collate_fn = collate_fn
        if self.collate_fn is None:
            self.collate_fn = collate_batch
        self.auto_collator = collator
        self.cur_batch_indices = None

    def __iter__(self):
        # TODO after the first iteration, setting collate_fn has no effect
        if self.cur_batch_indices is None:
            self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn,
                                                                                             self.auto_collator)))
        for indices, data in self.dataset.__iter__():
            self.cur_batch_indices = indices
            yield data

    def __len__(self):
        if self.dataset.drop_last:
            return len(self.dataset) // self.dataset.batch_size
        return (len(self.dataset) - 1) // self.dataset.batch_size + 1

    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
        """
        Set the padding value for each field_name (default 0). The call only takes effect when an
        auto collator exists; if there is none, an auto_collator is added first.
        When val is None, none of the given field_names will be padded.

        :param field_names:
        :param val: padding value, default 0
        :return:
        """
        if self.auto_collator is None:
            self.auto_collator = AutoCollator(as_numpy=True)
        self.auto_collator.set_pad_val(*field_names, val=val)

    def set_input(self, *field_names) -> None:
        """
        The field_names marked as inputs are fed to the AutoCollator; fields that are not marked
        are filtered out by default.

        :param field_names:
        :return:
        """
        if self.auto_collator is None:
            self.auto_collator = AutoCollator(as_numpy=True)
        self.auto_collator.set_input(*field_names)

    def get_batch_indices(self) -> List[int]:
        """
        Get the indices of the current batch.

        :return:
        """
        return self.cur_batch_indices


def prepare_jittor_dataloader():
    ...
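A minimal usage sketch of the JittorDataLoader defined above, exercising the auto-collate path. The dict-of-lists DataSet constructor, the field names 'x' and 'y', and the loop body are illustrative assumptions, not taken from this diff:

# hedged sketch: auto padding of a fastNLP DataSet through JittorDataLoader
from fastNLP.core.dataset import DataSet

ds = DataSet({'x': [[1, 2], [2, 3, 4], [4, 5, 6, 7]], 'y': [1, 0, 1]})  # assumed constructor
dl = JittorDataLoader(ds, batch_size=2, shuffle=False)
dl.set_input('x', 'y')        # only these fields reach the AutoCollator
dl.set_pad_val('x', val=0)    # pad 'x' with zeros to the longest sequence in the batch
for batch in dl:
    indices = dl.get_batch_indices()  # indices of the samples in this batch
    ...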
@@ -0,0 +1,194 @@
__all__ = [
    'MixDataLoader'
]

from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence

import numpy as np

from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from torch.utils.data import DataLoader, Sampler
else:
    from fastNLP.core.utils.dummy_class import DummyClass as DataLoader
    # also alias Sampler so that the isinstance(mode, Sampler) check below does not
    # raise a NameError when torch is unavailable
    from fastNLP.core.utils.dummy_class import DummyClass as Sampler


class _MixDataset:
    """
    Treats all datasets as one large mixed dataset; its __getitem__ can tell which
    dataset a given idx belongs to.
    """

    def __init__(self, datasets: list = None) -> None:
        """
        :param datasets: list of datasets
        """
        self.datasets = datasets
        # record the cumulative length of each dataset so that an idx can be located
        self.lens = []
        index = 0
        for item in self.datasets:
            index += len(item)
            self.lens.append(index)

    def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]:
        """
        :param idx:
        :return:
        """
        if isinstance(idx, int):
            if idx >= self.lens[-1]:
                raise ValueError(f"idx: {idx} out of range")
            # find which dataset the idx belongs to and return its local index
            ds_index = np.searchsorted(self.lens, idx, side='right')
            if ds_index > 0:
                idx -= self.lens[ds_index - 1]
            return self.datasets[ds_index][idx], ds_index
        elif isinstance(idx, list):
            # a list of indices is expected to come from a single dataset, otherwise an error is raised
            dataset = DataSet()
            ds_index = 0
            for i in idx:
                assert isinstance(i, int), "Only int index allowed."
                instance, ds_index = self[i]
                dataset.append(instance)
            return dataset, ds_index
        else:
            raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

    def __len__(self) -> int:
        return self.lens[-1]


class _MixCollateFn:
    """
    When there are several auto_collates and several collate_fns, decides which
    auto_collate and which collate_fn apply to a given batch.
    """

    def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None,
                 auto_collators: Optional[List[Callable]] = None) -> None:
        if isinstance(collate_fns, Sequence):
            self.collate_fns = lambda idx, lst: collate_fns[idx](lst)
        elif callable(collate_fns):
            self.collate_fns = lambda idx, lst: collate_fns(lst)
        else:
            self.collate_fns = lambda idx, lst: lst
        self.auto_collators = auto_collators

    def __call__(self, ins_list: List) -> Dict:
        """
        A single call treats ins_list as sampled from one dataset, so ds_index can only
        take a single value.

        :param ins_list:
        :return:
        """
        _ins_list, _ds_index = [], 0
        for ins, _ds_index in ins_list:
            _ins_list.append(ins)
        # apply the auto_collate first
        if self.auto_collators is not None:
            _ins_list = self.auto_collators[_ds_index](_ins_list)
        _ins_list = self.collate_fns(_ds_index, _ins_list)
        return _ins_list


class MixDataLoader(DataLoader):
    """
    MixDataLoader covers the following three situations:

    1. Given a dict or list of datasets, sample them sequentially: finish sampling the first
       dataset, then move on to the second, and repeat until all datasets are exhausted.
    2. Given a dict or list of datasets, randomly sample from any of them and mix the results
       into a single batch, until every dataset has been fully sampled.
    3. Given a dict or list of datasets, sample them in turn (polling): iterate over the datasets,
       draw one batch from a dataset, then one batch from the next dataset, and repeat until one
       dataset (or all datasets) is exhausted.
    """

    def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential',
                 collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None,
                 sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None,
                 num_workers: int = 0, batch_size: int = 16, drop_last=False,
                 ds_ratio: Union[str, List[float], None, Dict[str, float]] = None,
                 pin_memory: bool = True) -> None:
        """
        :param datasets: list of datasets
        :param mode: one of four options; the first three, "sequential", "mix" and "polling",
            correspond to the three situations above. When mode is a Sampler it is user-defined;
            in that case sampler, ds_ratio, batch_size and drop_last are ignored, and the Sampler
            must be an iterable that yields a List[int] on each iteration.
        :param collate_fn: callable used to collate the fetched samples.
            When it is a single callable, it is applied to the data sampled from every dataset;
            when it is List[Callable], datasets must also be a List; the idx returned by each
            dataset's __getitem__ decides which callable applies, matched by position;
            when it is Dict[str, Callable], datasets must also be a Dict with matching keys.
        :param sampler: the instantiated sampler used inside each dataset of datasets.
            When sampler is None, a SequentialSampler is created for every dataset in datasets;
            when sampler is List[Sampler], datasets must also be a List, matched by position;
            when sampler is Dict[str, Sampler], datasets must also be a Dict with matching keys.
        :param num_workers: number of worker processes; 0 disables multiprocessing
        :param batch_size: batch size, shared by all datasets in datasets
        :param drop_last: whether to drop the last batch when it is smaller than batch_size
        :param ds_ratio: when ds_ratio is None, the datasets are not resampled;
            when ds_ratio is 'truncate_to_least', the shortest dataset is the reference and the
            other datasets are truncated to the same length;
            when ds_ratio is 'pad_to_most', the longest dataset is the reference and the shorter
            datasets are resampled until they match its length;
            when ds_ratio is List[float], datasets must also be a List and each entry is the
            sampling multiplier for the corresponding dataset; it must be greater than 0 and may
            exceed 1, in which case the dataset is simply oversampled;
            when ds_ratio is Dict[str, float], datasets must also be a Dict with matching keys.
        """
        # if datasets is a Dict, the other arguments such as collate_fn must be a Dict or a Callable
        if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \
                isinstance(sampler, Dict):
            raise ValueError("when datasets is not a Dict, sampler should not be a Dict either")

        if isinstance(collate_fn, list):
            if len(collate_fn) != len(datasets):
                raise ValueError("the length of collate_fn != datasets!!")

        if isinstance(sampler, list):
            if len(sampler) != len(datasets):
                raise ValueError("the length of sampler != datasets!!")

        # convert Dict to List so that _MixCollateFn can handle it
        if isinstance(collate_fn, Dict):
            collate_fn = [fn for _, fn in collate_fn.items()]

        # datasets may contain fastNLP datasets, other datasets, or a mixture, so check each one
        if isinstance(datasets, Dict):
            dataset = [ds for _, ds in datasets.items()]
        else:
            dataset = datasets
        auto_collators = []
        for per_ds in dataset:
            if isinstance(per_ds, DataSet):
                auto_collators.append(per_ds.get_collator())
            else:
                # if there is no matching collator, use one that does nothing
                auto_collators.append(lambda x: x)

        # collate_fn is now either a single callable or a list, so wrap it
        collate_fn = _MixCollateFn(collate_fn, auto_collators)
        if mode == 'sequential':
            batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler,
                                                 drop_last=drop_last, ds_ratio=ds_ratio)
        elif mode == 'polling':
            batch_sampler = PollingSampler(datasets, batch_size=batch_size, sampler=sampler,
                                           drop_last=drop_last, ds_ratio=ds_ratio)
        elif mode == 'mix':
            batch_sampler = DopedSampler(datasets, batch_size=batch_size, sampler=sampler,
                                         drop_last=drop_last, ds_ratio=ds_ratio)
        elif isinstance(mode, Sampler):
            batch_sampler = mode
        else:
            raise ValueError(f"mode: {mode} must be 'sequential', 'polling', 'mix' or an instance of Sampler")

        super(MixDataLoader, self).__init__(
            _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
            batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
            pin_memory=pin_memory, drop_last=False, timeout=0,
            worker_init_fn=None, multiprocessing_context=None, generator=None,
            prefetch_factor=2, persistent_workers=False
        )

    def __iter__(self):
        return super().__iter__()
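A sketch of how the three modes described in the class docstring are selected. The two toy DataSet objects, their field names, and the dict-of-lists constructor are illustrative assumptions; the exact sampling behaviour is determined by the samplers imported at the top of this file:

# hedged sketch of MixDataLoader over two fastNLP datasets
from fastNLP.core.dataset import DataSet

ds_a = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})            # assumed constructor
ds_b = DataSet({'x': [[6], [7, 8], [9, 10, 11]], 'y': [1, 0, 1]})

# 1. exhaust ds_a first, then ds_b
dl_seq = MixDataLoader(datasets=[ds_a, ds_b], mode='sequential', batch_size=2)
# 2. every batch may mix instances drawn from both datasets
dl_mix = MixDataLoader(datasets=[ds_a, ds_b], mode='mix', batch_size=2)
# 3. alternate: one batch from ds_a, one from ds_b, and so on
dl_poll = MixDataLoader(datasets={'a': ds_a, 'b': ds_b}, mode='polling',
                        ds_ratio='pad_to_most', batch_size=2)
for batch in dl_seq:
    ...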
@@ -0,0 +1,6 @@
__all__ = [
    'prepare_paddle_dataloader',
    'PaddleDataLoader'
]

from .fdl import PaddleDataLoader, prepare_paddle_dataloader
@@ -0,0 +1,192 @@
__all__ = [
    'PaddleDataLoader',
    'prepare_paddle_dataloader'
]

from typing import Callable, List, Optional, Union, Dict, Sequence

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
    from paddle.io import DataLoader, Dataset
    from paddle.fluid.dataloader.collate import default_collate_fn
else:
    from fastNLP.core.utils.dummy_class import DummyClass as Dataset
    from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet


class _PaddleDataset(Dataset):
    """
    Wraps the dataset passed in by the user so that PaddleDataLoader can drive a
    custom dataset through paddle's dataloader.
    """

    def __init__(self, dataset) -> None:
        super(_PaddleDataset, self).__init__()
        self.dataset = dataset

    def __getitem__(self, item):
        # return the index together with the sample so that the collate wrapper
        # can recover which indices make up the current batch
        return (item, self.dataset[item])

    def __len__(self) -> int:
        return len(self.dataset)

    def __getattr__(self, item):
        # forward attribute lookups to the wrapped dataset
        try:
            return self.dataset.__getattribute__(item)
        except Exception as e:
            raise e


class PaddleDataLoader(DataLoader):

    def __init__(self, dataset, feed_list=None, places=None,
                 return_list: bool = True, batch_sampler=None,
                 batch_size: int = 1, shuffle: bool = False,
                 drop_last: bool = False, collate_fn: Callable = None,
                 num_workers: int = 0, use_buffer_reader: bool = True,
                 use_shared_memory: bool = True, timeout: int = 0,
                 worker_init_fn: Callable = None, persistent_workers=False) -> None:

        if not isinstance(dataset, _PaddleDataset):
            dataset = _PaddleDataset(dataset)

        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
                                               return_list=return_list, batch_sampler=batch_sampler,
                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                               collate_fn=None, num_workers=num_workers,
                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
                                               timeout=timeout, worker_init_fn=worker_init_fn,
                                               persistent_workers=persistent_workers)
        if isinstance(dataset.dataset, FDataSet):
            self._collate_fn = dataset.dataset.get_collator()
            self._collate_fn.set_as_numpy(as_numpy=True)
            if collate_fn is not None:
                self._collate_fn.add_collator(collate_fn)
        else:
            self._collate_fn = _MultiCollator(collate_fn)
        # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
        # if collate_fn is not None:
        #     _collate_fn.add_collator(collate_fn)
        # self._collate_fn = _collate_fn
        self.cur_batch_indices = None

    def __iter__(self):
        # if there is neither an auto_collator nor a user collate_fn, fall back to the
        # dataloader's default collate_fn, which simply batches the data
        if len(self._collate_fn.get_collators()) == 0:
            self._collate_fn.add_collator(default_collate_fn)
        # self._collate_fn = default_collate_fn
        self.collate_fn = indice_collate_wrapper(self._collate_fn)
        for indices, data in super().__iter__():
            self.cur_batch_indices = indices
            yield data

    def __getattr__(self, item):
        """
        Exposes the dataset's methods and attributes on the dataloader; with this in place,
        users can call dataset methods such as apply directly on the instantiated dataloader.

        :param item:
        :return:
        """
        try:
            return self.dataset.__getattr__(item)
        except AttributeError as e:
            raise e

    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
        """
        Set the padding value for each field_name (default 0). The call only takes effect when an
        auto collator exists; if there is none, an auto_collator is added first.
        When val is None, none of the given field_names will be padded.

        :param field_names:
        :param val: padding value, default 0
        :return:
        """
        for field_name in field_names:
            self._collate_fn.set_pad_val(field_name, val=val)

    def set_input(self, *field_names) -> None:
        """
        The field_names marked as inputs are fed to the AutoCollator; fields that are not marked
        are filtered out by default.

        :param field_names:
        :return:
        """
        self._collate_fn.set_input(*field_names)

    def set_collator(self, collator: Callable) -> None:
        """
        Set the collate_fn; calling this replaces all existing collate_fns, including the
        auto collator.

        :param collator: a user-defined callable
        :return:
        """
        self._collate_fn = _MultiCollator(collator)

    def add_collator(self, collator) -> None:
        """
        Add a collate_fn; it is appended after the existing collate_fns.

        :param collator:
        :return:
        """
        self._collate_fn.add_collator(collator)

    def get_batch_indices(self) -> List[int]:
        """
        Get the indices of the current batch.

        :return:
        """
        return self.cur_batch_indices


def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                              return_list: bool = True, batch_sampler=None,
                              train_batch_size: int = 1, shuffle: bool = False,
                              drop_last: bool = False, collate_fn: Callable = None,
                              num_workers: int = 0, use_buffer_reader: bool = True,
                              use_shared_memory: bool = True, timeout: int = 0,
                              worker_init_fn: Callable = None, persistent_workers=False,
                              non_train_batch_size: int = 16,
                              input_fields: Union[List[str], str] = None) \
        -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
    if isinstance(input_fields, str):
        input_fields = [input_fields]
    elif input_fields is None:
        # avoid unpacking None in the set_input calls below
        input_fields = []

    if isinstance(ds_or_db, Dataset):
        ...
    elif isinstance(ds_or_db, Sequence):
        ds_seq = []
        for ds in ds_or_db:
            dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
                                  batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
                                  drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                  use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                  timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
            dl.set_input(*input_fields)
            ds_seq.append(dl)
        return ds_seq

    elif isinstance(ds_or_db, Dict):
        ds_dict = {}
        for name, ds in ds_or_db.items():
            if 'train' in name:
                dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
                                      batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
                                      drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                      use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                      timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
            else:
                dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
                                      batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle,
                                      drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                      use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                      timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
            dl.set_input(*input_fields)
            ds_dict[name] = dl
        return ds_dict
    else:
        raise ValueError(f"ds_or_db: {ds_or_db} must be a fastNLP dataset, a data_bundle, a sequence or a mapping!")
@@ -0,0 +1,6 @@
__all__ = [
    'TorchDataLoader',
    'prepare_torch_dataloader'
]

from .fdl import TorchDataLoader, prepare_torch_dataloader
@@ -0,0 +1,300 @@
__all__ = [
    'TorchDataLoader',
    'prepare_torch_dataloader'
]

from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    from torch.utils.data import DataLoader, Sampler
    from torch.utils.data._utils.collate import default_collate
else:
    from ..fdataloader import FDataLoader as DataLoader


class _FDataSet:
    """
    Wrapper around a dataset whose main job is to change the dataset's __getitem__ so that it
    also returns the index idx. Note that the wrapped dataset needs to implement __getattribute__
    for its methods to be reachable through _FDataSet.
    """

    def __init__(self, dataset) -> None:
        self.dataset = dataset

    def __getitem__(self, item: Union[int, list]) -> Tuple:
        return (item, self.dataset[item])

    def __getattr__(self, item):
        try:
            return self.dataset.__getattribute__(item)
        except AttributeError as e:
            raise e

    def __len__(self) -> int:
        return len(self.dataset)


class TorchDataLoader(DataLoader):
    """
    DataLoader for the pytorch framework. When used together with a fastNLP dataset, it can pad
    the data automatically through the AutoCollate mechanism; the methods provided here also let
    users adjust the collate_fn settings.
    """

    def __init__(self, dataset, batch_size: int = 1,
                 shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
                 batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
                 num_workers: int = 0, collate_fn: Optional[Callable] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                 multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                 persistent_workers: bool = False, as_numpy: bool = False) -> None:
        """
        :param dataset: a data container implementing __getitem__ and __len__
        :param batch_size: batch size, effective when batch_sampler is None
        :param shuffle: whether to shuffle the dataset
        :param sampler: an instantiated sampler object
        :param batch_sampler: an instantiated batch_sampler object that yields a list of indices per iteration
        :param num_workers: number of worker processes; 0 disables multiprocessing
        :param collate_fn: callable used to collate the fetched samples into a batch
        :param pin_memory:
        :param drop_last: whether to drop the last batch when it is smaller than batch_size
        :param timeout:
        :param worker_init_fn:
        :param multiprocessing_context:
        :param generator:
        :param prefetch_factor:
        :param persistent_workers:
        :param as_numpy: whether the returned data should be numpy arrays instead of torch.Tensor
        """
        if not isinstance(dataset, _FDataSet):
            dataset = _FDataSet(dataset)

        super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                         multiprocessing_context=multiprocessing_context, generator=generator,
                         prefetch_factor=prefetch_factor,
                         persistent_workers=persistent_workers)
        if isinstance(dataset.dataset, DataSet):  # a fastNLP dataset is being used
            self._collate_fn = dataset.dataset.get_collator()
            self._collate_fn.set_as_numpy(as_numpy)
            if collate_fn is not None and collate_fn is not default_collate:
                # prevent torch dataloader's default collate from being added when ddp re-initializes
                self._collate_fn.add_collator(collate_fn)
        else:
            self._collate_fn = _MultiCollator(collate_fn)

        self.cur_batch_indices = None
        self.as_numpy = as_numpy

    def __getattr__(self, item):
        """
        Exposes the dataset's methods and attributes on the dataloader; with this in place,
        users can call dataset methods such as apply directly on the instantiated dataloader.

        :param item:
        :return:
        """
        try:
            return self.dataset.__getattr__(item)
        except AttributeError as e:
            raise e

    def __iter__(self):
        # if there is neither an auto_collator nor a user collate_fn, fall back to the
        # dataloader's own collate_fn, which simply batches the data
        if len(self._collate_fn.get_collators()) == 0:
            self._collate_fn.add_collator(self.collate_fn)
        self.collate_fn = indice_collate_wrapper(self._collate_fn)
        for indices, data in super().__iter__():
            self.cur_batch_indices = indices
            yield data

    def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
        """
        Set the padding value for each field_name (default 0). The call only takes effect when an
        auto collator exists; if there is none, an auto_collator is added first.
        When val is None, none of the given field_names will be padded.

        :param field_names:
        :param val: padding value, default 0
        :return:
        """
        flag = False
        for collator in self._collate_fn.get_collators():
            if isinstance(collator, AutoCollator):
                flag = True
                break
        if flag is False:
            self._collate_fn.add_collator(AutoCollator(self.as_numpy))
        for field_name in field_names:
            self._collate_fn.set_pad_val(field_name, val=val)

    def set_input(self, *field_names) -> None:
        """
        The field_names marked as inputs are fed to the AutoCollator; fields that are not marked
        are filtered out by default.

        :param field_names:
        :return:
        """
        flag = False
        for collator in self._collate_fn.get_collators():
            if isinstance(collator, AutoCollator):
                flag = True
                break
        if flag is False:
            self._collate_fn.add_collator(AutoCollator(self.as_numpy))
        self._collate_fn.set_input(*field_names)

    def set_collator(self, collator: Callable) -> None:
        """
        Set the collate_fn; calling this replaces all existing collate_fns, including the
        auto collator.

        :param collator: a user-defined callable
        :return:
        """
        self._collate_fn = _MultiCollator(collator)

    def add_collator(self, collator) -> None:
        """
        Add a collate_fn; it is appended after the existing collate_fns.

        :param collator:
        :return:
        """
        self._collate_fn.add_collator(collator)

    def get_batch_indices(self) -> List[int]:
        """
        Get the indices of the current batch.

        :return:
        """
        return self.cur_batch_indices


def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
                             batch_size: int = 1,
                             shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
                             batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
                             num_workers: int = 0, collate_fn: Optional[Callable] = None,
                             pin_memory: bool = False, drop_last: bool = False,
                             timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                             multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
                             persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
                             non_train_batch_size: int = 16, as_numpy: bool = False,
                             input_fields: Union[List, str] = None) \
        -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
    """
    Given a dataset or a data_bundle, build and return the corresponding TorchDataLoader instance(s).

    :param input_fields:
    :param ds_or_db: a dataset or a data_bundle
    :param batch_size: batch size, effective when batch_sampler is None
    :param shuffle: whether to shuffle the dataset
    :param sampler: an instantiated sampler object
    :param batch_sampler: an instantiated batch_sampler object that yields a list of indices per iteration
    :param num_workers: number of worker processes; 0 disables multiprocessing
    :param collate_fn: callable used to collate the fetched samples into a batch
    :param pin_memory:
    :param drop_last: whether to drop the last batch when it is smaller than batch_size
    :param timeout:
    :param worker_init_fn:
    :param multiprocessing_context:
    :param generator:
    :param prefetch_factor:
    :param persistent_workers:
    :param non_train_sampler: Sampler used for non-'train' datasets, and for the second and later
        datasets when a Sequence is passed
    :param non_train_batch_size:
    :param as_numpy: whether the returned data should be numpy arrays; otherwise torch.Tensor is used
        as appropriate.
    """
    # TODO needs to be provided for the dict and sequence cases as well
    if isinstance(input_fields, str):
        input_fields = [input_fields]
    elif input_fields is None:
        # avoid unpacking None in the set_input calls below
        input_fields = []

    if isinstance(ds_or_db, DataSet):
        dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                             shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                             num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                             drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                             multiprocessing_context=multiprocessing_context, generator=generator,
                             prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                             as_numpy=as_numpy)
        dl.set_input(*input_fields)
        return dl

    elif isinstance(ds_or_db, DataBundle):
        dl_bundle = {}
        for name, ds in ds_or_db.iter_datasets():
            if 'train' in name:
                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
                                                  shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                  num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                  drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                  multiprocessing_context=multiprocessing_context, generator=generator,
                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                                  as_numpy=as_numpy)
            else:
                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
                                                  shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
                                                  num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                  drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                  multiprocessing_context=multiprocessing_context, generator=generator,
                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                                  as_numpy=as_numpy)
            dl_bundle[name].set_input(*input_fields)
        return dl_bundle

    elif isinstance(ds_or_db, Sequence):
        dl_bundle = []
        for idx, ds in enumerate(ds_or_db):
            if idx == 0:
                dl_bundle.append(
                    TorchDataLoader(dataset=ds, batch_size=batch_size,
                                    shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                    num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                    drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                    multiprocessing_context=multiprocessing_context, generator=generator,
                                    prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                    as_numpy=as_numpy)
                )
            else:
                dl_bundle.append(
                    TorchDataLoader(dataset=ds, batch_size=batch_size,
                                    shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                    num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                    drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                    multiprocessing_context=multiprocessing_context, generator=generator,
                                    prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                    as_numpy=as_numpy)
                )
        for dl in dl_bundle:
            dl.set_input(*input_fields)
        return dl_bundle

    elif isinstance(ds_or_db, Mapping):
        dl_bundle = {}
        for name, ds in ds_or_db.items():
            if 'train' in name:
                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
                                                  shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                  num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                  drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                  multiprocessing_context=multiprocessing_context, generator=generator,
                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                                  as_numpy=as_numpy)
            else:
                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
                                                  shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
                                                  num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                  drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
                                                  multiprocessing_context=multiprocessing_context, generator=generator,
                                                  prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
                                                  as_numpy=as_numpy)
            dl_bundle[name].set_input(*input_fields)
        return dl_bundle
    else:
        raise ValueError(f"ds_or_db: {ds_or_db} must be a fastNLP dataset, a data_bundle, a sequence or a mapping!")