From 69102a8427a72ae7c8194f6a5543cce2644f6f61 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Fri, 8 Apr 2022 21:30:35 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86dataloaders?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/__init__.py | 14 + fastNLP/core/dataloaders/fdataloader.py | 7 + .../dataloaders/jittor_dataloader/__init__.py | 7 + .../core/dataloaders/jittor_dataloader/fdl.py | 138 ++++++++ fastNLP/core/dataloaders/mix_dataloader.py | 194 +++++++++++ .../dataloaders/paddle_dataloader/__init__.py | 6 + .../core/dataloaders/paddle_dataloader/fdl.py | 192 +++++++++++ .../dataloaders/torch_dataloader/__init__.py | 6 + .../core/dataloaders/torch_dataloader/fdl.py | 300 ++++++++++++++++++ 9 files changed, 864 insertions(+) create mode 100644 fastNLP/core/dataloaders/__init__.py create mode 100644 fastNLP/core/dataloaders/fdataloader.py create mode 100644 fastNLP/core/dataloaders/jittor_dataloader/__init__.py create mode 100644 fastNLP/core/dataloaders/jittor_dataloader/fdl.py create mode 100644 fastNLP/core/dataloaders/mix_dataloader.py create mode 100644 fastNLP/core/dataloaders/paddle_dataloader/__init__.py create mode 100644 fastNLP/core/dataloaders/paddle_dataloader/fdl.py create mode 100644 fastNLP/core/dataloaders/torch_dataloader/__init__.py create mode 100644 fastNLP/core/dataloaders/torch_dataloader/fdl.py diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py new file mode 100644 index 00000000..40dd7b1c --- /dev/null +++ b/fastNLP/core/dataloaders/__init__.py @@ -0,0 +1,14 @@ +__all__ = [ + 'MixDataLoader', + 'TorchDataLoader', + 'PaddleDataLoader', + 'JittorDataLoader', + 'prepare_jittor_dataloader', + 'prepare_paddle_dataloader', + 'prepare_torch_dataloader' +] + +from .mix_dataloader import MixDataLoader +from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader +from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader +from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader diff --git a/fastNLP/core/dataloaders/fdataloader.py b/fastNLP/core/dataloaders/fdataloader.py new file mode 100644 index 00000000..742f3909 --- /dev/null +++ b/fastNLP/core/dataloaders/fdataloader.py @@ -0,0 +1,7 @@ +__all__ = [ + 'FDataLoader' +] + + +class FDataLoader: + pass diff --git a/fastNLP/core/dataloaders/jittor_dataloader/__init__.py b/fastNLP/core/dataloaders/jittor_dataloader/__init__.py new file mode 100644 index 00000000..8aba7614 --- /dev/null +++ b/fastNLP/core/dataloaders/jittor_dataloader/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "JittorDataLoader", + 'prepare_jittor_dataloader' + +] + +from .fdl import JittorDataLoader, prepare_jittor_dataloader diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py new file mode 100644 index 00000000..2cf85fd8 --- /dev/null +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -0,0 +1,138 @@ +__all__ = [ + 'JittorDataLoader', + 'prepare_jittor_dataloader' +] + +from typing import Callable, Optional, List + +from fastNLP.envs.imports import _NEED_IMPORT_JITTOR +if _NEED_IMPORT_JITTOR: + from jittor.dataset.utils import collate_batch + from jittor.dataset import Dataset +else: + from fastNLP.core.dataset import DataSet as Dataset +from fastNLP.core.utils.jittor_utils import jittor_collate_wraps +from fastNLP.core.collators import AutoCollator +from fastNLP.core.utils.utils import 
indice_collate_wrapper +from fastNLP.core.dataset import DataSet as FDataSet + + +class _JittorDataset(Dataset): + """ + 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset使用jittor的dataset + """ + + def __init__(self, dataset) -> None: + super(_JittorDataset, self).__init__() + self.dataset = dataset + + def __getitem__(self, item): + return (item, self.dataset[item]) + + def __len__(self) -> int: + return len(self.dataset) + + # def __getattr__(self, item): + # # jittor的Dataset没有的方法而用户的dataset存在且实现了getattribute方法,此时用户可以调用 + # try: + # self.dataset.__getattribute__(item) + # except Exception as e: + # raise e + + +class JittorDataLoader: + """ + 提供给使用jittor框架的DataLoader函数,提供了auto_collate的功能, 支持实现了__getitem__和__len__的dataset + """ + + def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, + drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, + stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, + collate_fn: Callable = None) -> None: + """ + + :param dataset: 实现__getitem__和__len__的dataset + :param batch_size: 批次大小 + :param shuffle: 是否打乱数据集 + :param drop_last: 是否去掉最后一个不符合batch_size的数据 + :param num_workers: 进程的数量,当num_workers=0时不开启多进程 + :param buffer_size: + :param stop_grad: + :param keep_numpy_array: + :param endless: + :param collate_fn: 对取得到的数据进行打包的callable函数 + :param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 + """ + # TODO 支持fastnlp dataset + # TODO 验证支持replacesampler (以后完成) + # 是否为 jittor 类型的 dataset + + if isinstance(dataset, FDataSet): + collator = dataset.get_collator().set_as_numpy(as_numpy=True) + else: + collator = None + + self.dataset = _JittorDataset(dataset) + + self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, + keep_numpy_array=keep_numpy_array, endless=endless) + if isinstance(self.dataset.dataset, Dataset): + self.dataset.dataset.set_attrs(batch_size=1) + # 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 + self.collate_fn = collate_fn + if self.collate_fn is None: + self.collate_fn = collate_batch + self.auto_collator = collator + self.cur_batch_indices = None + + def __iter__(self): + # TODO 第一次迭代后不能设置collate_fn,设置是无效的 + if self.cur_batch_indices is None: + self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn, + self.auto_collator))) + for indices, data in self.dataset.__iter__(): + self.cur_batch_indices = indices + yield data + + def __len__(self): + if self.dataset.drop_last: + return len(self.dataset) // self.dataset.batch_size + return (len(self.dataset) - 1) // self.dataset.batch_size + 1 + + def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: + """ + 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 + 当val=None时,意味着给定的field_names都不需要尝试padding + + :param field_names: + :param val: padding值,默认为0 + :return: + """ + if self.auto_collator is None: + self.auto_collator = AutoCollator(as_numpy=True) + self.auto_collator.set_pad_val(*field_names, val=val) + + def set_input(self, *field_names) -> None: + """ + 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 + + :param field_names: + :return: + """ + if self.auto_collator is None: + self.auto_collator = AutoCollator(as_numpy=True) + + self.auto_collator.set_input(*field_names) + + def get_batch_indices(self) -> List[int]: + """ + 获取当前数据的idx + + :return: + """ + return self.cur_batch_indices + + +def prepare_jittor_dataloader(): + ... 
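As a quick usage sketch for the JittorDataLoader defined above (illustration only, not part of the patch): the toy DataSet, its dict constructor and the field names 'words'/'target' are assumptions; the JittorDataLoader, set_input, set_pad_val and get_batch_indices calls mirror the API added in jittor_dataloader/fdl.py.

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders import JittorDataLoader

    # hypothetical toy dataset; field names are illustrative only
    ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'target': [0, 1]})

    dl = JittorDataLoader(ds, batch_size=2, shuffle=True, drop_last=False)
    dl.set_input('words', 'target')   # fields forwarded to the AutoCollator
    dl.set_pad_val('words', val=0)    # pad 'words' with 0; val=None disables padding

    for batch in dl:
        indices = dl.get_batch_indices()   # sample indices of the current batch
        ...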
diff --git a/fastNLP/core/dataloaders/mix_dataloader.py b/fastNLP/core/dataloaders/mix_dataloader.py new file mode 100644 index 00000000..7ff3eb32 --- /dev/null +++ b/fastNLP/core/dataloaders/mix_dataloader.py @@ -0,0 +1,194 @@ +__all__ = [ + 'MixDataLoader' +] + +from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence + +import numpy as np + +from fastNLP.core.dataset import DataSet, Instance +from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader, Sampler +else: + from fastNLP.core.utils.dummy_class import DummyClass as DataLoader + + +class _MixDataset: + """ + 将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx + + """ + def __init__(self, datasets: list = None) -> None: + """ + + :param datasets: 数据集的列表 + """ + self.datasets = datasets + # 记录每个数据集的长度索引, 以便根据idx定位数据集的位置 + self.lens = [] + index = 0 + for item in self.datasets: + index += len(item) + self.lens.append(index) + + def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: + """ + + :param idx: + :return: + """ + if isinstance(idx, int): + if idx >= self.lens[-1]: + raise ValueError(f"idx: {idx} out of range") + # 找到其属于哪个数据集,返回下标 + ds_index = np.searchsorted(self.lens, idx, side='right') + if ds_index > 0: + idx -= self.lens[ds_index - 1] + return self.datasets[ds_index][idx], ds_index + elif isinstance(idx, list): + # 一般一个list列表只能是属于一种数据的,否则会报错 + dataset = DataSet() + ds_index = 0 + for i in idx: + assert isinstance(i, int), "Only int index allowed." + instance, ds_index = self[i] + dataset.append(instance) + return dataset, ds_index + else: + raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) + + def __len__(self) -> int: + return self.lens[-1] + + +class _MixCollateFn: + """ + 存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 + + """ + def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, + auto_collators: Optional[List[Callable]] = None) -> None: + if isinstance(collate_fns, Sequence): + self.collate_fns = lambda idx, lst: collate_fns[idx](lst) + elif callable(collate_fns): + self.collate_fns = lambda idx, lst: collate_fns(lst) + else: + self.collate_fns = lambda idx, lst: lst + + self.collate_fns = collate_fns + self.auto_collators = auto_collators + + def __call__(self, ins_list: List) -> Dict: + """ + 调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 + :param ins_list: + :return: + """ + _ins_list, _ds_index = [], 0 + for ins, _ds_index in ins_list: + _ins_list.append(ins) + # auto_collate先处理 + if self.auto_collators is not None: + _ins_list = self.auto_collators[_ds_index](_ins_list) + _ins_list = self.collate_fns(_ds_index, _ins_list) + return _ins_list + + +class MixDataLoader(DataLoader): + """ + 针对一下三种情况提供的MixDataLoader: + 1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。 + 2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。 + 3. 
给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset + 采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。 + """ + def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential', + collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None, + sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None, + num_workers: int = 0, batch_size: int = 16, drop_last=False, + ds_ratio: Union[str, List[float], None, Dict[str, float]] = None, + pin_memory: bool = True) -> None: + """ + + :param datasets: dataset的列表 + :param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况, + 当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代 + 对象,每次迭代返回的是List[int] + :param collate_fn: 对取得到的数据进行打包的callable函数, + 当其为callable类型时候,所有数据集采样的数据都会经过这个函数; + 当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数, + 其对应关系与datasets位置匹配; + 当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。 + :param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象 + sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样; + sampler为List[Sampler],则datasets也为List,且一一对应 + sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。 + :param num_workers: 进程的数量,当num_workers=0时不开启多进程 + :param batch_size: 批次大小, datasets的所有数据集batch_size一致 + :param drop_last: 是否去掉最后一个不符合batch_size的数据 + :param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充 + 当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度 + 当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止 + 当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数, + 其大于0,可以超过1,将数据集重采样翻倍即可 + 当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。 + """ + # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, + if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \ + isinstance(sampler, Dict): + raise ValueError(f"") + + if isinstance(collate_fn, list): + if len(collate_fn) != len(datasets): + raise ValueError("the length of collate_fn != datasets!!") + + if isinstance(sampler, list): + if len(sampler) != len(datasets): + raise ValueError("the length of sampler != datasets!!") + + # Dict类型转化为List,以便于_MixCollateFn处理 + if isinstance(collate_fn, Dict): + collate_fn = [fn for _, fn in collate_fn.items()] + + # 由于datasets可能是FastNLP类型的dataset或者是交杂的, 故需要检测 + if isinstance(datasets, Dict): + dataset = [ds for _, ds in datasets.items()] + else: + dataset = datasets + auto_collators = [] + for per_ds in dataset: + if isinstance(per_ds, DataSet): + auto_collators.append(per_ds.get_collator()) + else: + # 如果没有对应的collator就设置一个不做任何操作的collator + auto_collators.append(lambda x: x) + + # List类型的collate_fn只有两种情况,需要对其进行包裹 + collate_fn = _MixCollateFn(collate_fn, auto_collators) + if mode == 'sequential': + batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, + drop_last=drop_last, ds_ratio=ds_ratio) + elif mode == 'polling': + batch_sampler = PollingSampler(datasets, batch_size=batch_size, sampler=sampler, + drop_last=drop_last, ds_ratio=ds_ratio) + elif mode == 'mix': + batch_sampler = DopedSampler(datasets, batch_size=batch_size, sampler=sampler, + drop_last=drop_last, ds_ratio=ds_ratio) + elif isinstance(mode, Sampler): + batch_sampler = mode + else: + raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler") + + super(MixDataLoader, self).__init__( + _MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, + batch_sampler=batch_sampler, 
num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=False, timeout=0, + worker_init_fn=None, multiprocessing_context=None, generator=None, + prefetch_factor=2, persistent_workers=False + ) + + def __iter__(self): + return super().__iter__() diff --git a/fastNLP/core/dataloaders/paddle_dataloader/__init__.py b/fastNLP/core/dataloaders/paddle_dataloader/__init__.py new file mode 100644 index 00000000..ab9523e5 --- /dev/null +++ b/fastNLP/core/dataloaders/paddle_dataloader/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + 'prepare_paddle_dataloader', + 'PaddleDataLoader' +] + +from .fdl import PaddleDataLoader, prepare_paddle_dataloader \ No newline at end of file diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py new file mode 100644 index 00000000..b54b9cff --- /dev/null +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -0,0 +1,192 @@ +__all__ = [ + 'PaddleDataLoader', + 'prepare_paddle_dataloader' +] + +from typing import Callable, List, Optional, Union, Dict, Sequence + +from fastNLP.envs.imports import _NEED_IMPORT_PADDLE +if _NEED_IMPORT_PADDLE: + from paddle.io import DataLoader, Dataset + from paddle.fluid.dataloader.collate import default_collate_fn +else: + from fastNLP.core.utils.dummy_class import DummyClass as Dataset + from fastNLP.core.utils.dummy_class import DummyClass as DataLoader + +from fastNLP.core.collators.collator import _MultiCollator +from fastNLP.core.utils.utils import indice_collate_wrapper +from fastNLP.core.dataset import DataSet as FDataSet + + +class _PaddleDataset(Dataset): + """ + 对用户传的dataset进行封装,以便Fdataloader能够支持使用自定义的dataset使用paddle的dataloader + """ + + def __init__(self, dataset) -> None: + super(_PaddleDataset, self).__init__() + self.dataset = dataset + + def __getitem__(self, item): + return (item, self.dataset[item]) + + def __len__(self) -> int: + return len(self.dataset) + + def __getattr__(self, item): + try: + self.dataset.__getattribute__(item) + except Exception as e: + raise e + + +class PaddleDataLoader(DataLoader): + + def __init__(self, dataset, feed_list=None, places=None, + return_list: bool = True, batch_sampler=None, + batch_size: int = 1, shuffle: bool = False, + drop_last: bool = False, collate_fn: Callable = None, + num_workers: int = 0, use_buffer_reader: bool = True, + use_shared_memory: bool = True, timeout: int = 0, + worker_init_fn: Callable = None, persistent_workers=False) -> None: + + if not isinstance(dataset, _PaddleDataset): + dataset = _PaddleDataset(dataset) + + super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, + return_list=return_list, batch_sampler=batch_sampler, + batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, + collate_fn=None, num_workers=num_workers, + use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, + timeout=timeout, worker_init_fn=worker_init_fn, + persistent_workers=persistent_workers) + if isinstance(dataset.dataset, FDataSet): + self._collate_fn = dataset.dataset.get_collator() + self._collate_fn.set_as_numpy(as_numpy=True) + if collate_fn is not None: + self._collate_fn.add_collator(collate_fn) + else: + self._collate_fn = _MultiCollator(collate_fn) + # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) + # if collate_fn is not None: + # _collate_fn.add_collator(collate_fn) + # self._collate_fn = _collate_fn + self.cur_batch_indices = None + + def __iter__(self): + # 如果没有auto_collator 也没有自定义collate_fn, 
那么此时采用dataloader自带的collate_fn, 将数据打包即可。 + if len(self._collate_fn.get_collators()) == 0: + self._collate_fn.add_collator(default_collate_fn) + # self._collate_fn = default_collate_fn + self.collate_fn = indice_collate_wrapper(self._collate_fn) + for indices, data in super().__iter__(): + self.cur_batch_indices = indices + yield data + + def __getattr__(self, item): + """ + 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 + + :param item: + :return: + """ + try: + return self.dataset.__getattr__(item) + except AttributeError as e: + raise e + + def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: + """ + 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 + 当val=None时,意味着给定的field_names都不需要尝试padding + + :param field_names: + :param val: padding值,默认为0 + :return: + """ + for field_name in field_names: + self._collate_fn.set_pad_val(field_name, val=val) + + def set_input(self, *field_names) -> None: + """ + 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 + + :param field_names: + :return: + """ + self._collate_fn.set_input(*field_names) + + def set_collator(self, collator: Callable) -> None: + """ + 设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate + + :param collator: 用户自定义的Callable函数 + :return: + """ + self._collate_fn = _MultiCollator(collator) + + def add_collator(self, collator) -> None: + """ + 添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 + + :param collator: + :return: + """ + self._collate_fn.add_collator(collator) + + def get_batch_indices(self) -> List[int]: + """ + 获取当前数据的idx + + :return: + """ + return self.cur_batch_indices + + +def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, + return_list: bool = True, batch_sampler=None, + train_batch_size: int = 1, shuffle: bool = False, + drop_last: bool = False, collate_fn: Callable = None, + num_workers: int = 0, use_buffer_reader: bool = True, + use_shared_memory: bool = True, timeout: int = 0, + worker_init_fn: Callable = None, persistent_workers=False, + non_train_batch_size: int = 16, + input_fields: Union[List[str], str] = None)\ + -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: + if isinstance(input_fields, str): + input_fields = [input_fields] + + if isinstance(ds_or_db, Dataset): + ... 
+ elif isinstance(ds_or_db, Sequence): + ds_seq = [] + for ds in ds_or_db: + dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, + use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, + timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) + dl.set_input(*input_fields) + ds_seq.append(dl) + return ds_seq + + elif isinstance(ds_or_db, Dict): + ds_dict = {} + for name, ds in ds_or_db.items(): + if 'train' in name: + dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, + use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, + timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) + else: + dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, + batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle, + drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, + use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, + timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) + dl.set_input(*input_fields) + ds_dict[name] = dl + return ds_dict + else: + raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") diff --git a/fastNLP/core/dataloaders/torch_dataloader/__init__.py b/fastNLP/core/dataloaders/torch_dataloader/__init__.py new file mode 100644 index 00000000..4f3fc707 --- /dev/null +++ b/fastNLP/core/dataloaders/torch_dataloader/__init__.py @@ -0,0 +1,6 @@ +__all__ = [ + "TorchDataLoader", + "prepare_torch_dataloader" +] + +from .fdl import TorchDataLoader, prepare_torch_dataloader diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py new file mode 100644 index 00000000..0cae39ac --- /dev/null +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -0,0 +1,300 @@ +__all__ = [ + 'TorchDataLoader', + 'prepare_torch_dataloader' +] + +from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping + +from fastNLP.core.dataset import DataSet +from fastNLP.core.collators import AutoCollator +from fastNLP.core.collators.collator import _MultiCollator +from fastNLP.core.utils.utils import indice_collate_wrapper +from fastNLP.io.data_bundle import DataBundle +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader, Sampler + from torch.utils.data._utils.collate import default_collate +else: + from ..fdataloader import FDataLoader as DataLoader + + +class _FDataSet: + """ + 对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset + 中调用dataset的方法 + """ + def __init__(self, dataset) -> None: + self.dataset = dataset + + def __getitem__(self, item: Union[int, list]) -> Tuple: + return (item, self.dataset[item]) + + def __getattr__(self, item): + try: + return self.dataset.__getattribute__(item) + except AttributeError as e: + raise e + + def __len__(self) -> int: + return len(self.dataset) + + +class TorchDataLoader(DataLoader): + """ + 
提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过 + 提供的方法调节设置collate_fn的若干参数。 + """ + def __init__(self, dataset, batch_size: int = 1, + shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, + batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + num_workers: int = 0, collate_fn: Optional[Callable] = None, + pin_memory: bool = False, drop_last: bool = False, + timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context=None, generator=None, prefetch_factor: int = 2, + persistent_workers: bool = False, as_numpy: bool = False) -> None: + """ + + :param dataset: 实现了__getitem__和__len__的数据容器 + :param batch_size: 批次大小,当batch_sampler为None生效 + :param shuffle: 是否打乱数据集 + :param sampler: sampler实例化对象 + :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 + :param num_workers: 进程的数量,当num_worker=0时不开启多进程 + :param collate_fn: 对取得到的数据进行打包的callable函数 + :param pin_memory: + :param drop_last: 是否去掉最后一个不符合batch_size的数据 + :param timeout: + :param worker_init_fn: + :param multiprocessing_context: + :param generator: + :param prefetch_factor: + :param persistent_workers: + :param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 + """ + if not isinstance(dataset, _FDataSet): + dataset = _FDataSet(dataset) + + super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, + pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers) + if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset + self._collate_fn = dataset.dataset.get_collator() + self._collate_fn.set_as_numpy(as_numpy) + if collate_fn is not None and collate_fn is not default_collate: + # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 + self._collate_fn.add_collator(collate_fn) + else: + self._collate_fn = _MultiCollator(collate_fn) + + self.cur_indices_batch = None + self.as_numpy = as_numpy + + def __getattr__(self, item): + """ + 为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法 + :param item: + :return: + """ + try: + return self.dataset.__getattr__(item) + except AttributeError as e: + raise e + + def __iter__(self): + # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 + if len(self._collate_fn.get_collators()) == 0: + self._collate_fn.add_collator(self.collate_fn) + self.collate_fn = indice_collate_wrapper(self._collate_fn) + for indices, data in super().__iter__(): + self.cur_batch_indices = indices + yield data + + def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: + """ + 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 + 当val=None时,意味着给定的field_names都不需要尝试padding + :param field_names: + :param val: padding值,默认为0 + :return: + """ + flag = False + for collator in self._collate_fn.get_collators(): + if isinstance(collator, AutoCollator): + flag = True + break + if flag is False: + self._collate_fn.add_collator(AutoCollator(self.as_numpy)) + for field_name in field_names: + self._collate_fn.set_pad_val(field_name, val=val) + + def set_input(self, *field_names) -> None: + """ + 被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 + + :param field_names: + :return: + """ + flag = False + for collator in self._collate_fn.get_collators(): + if isinstance(collator, AutoCollator): + flag 
= True + break + if flag is False: + self._collate_fn.add_collator(AutoCollator(self.as_numpy)) + self._collate_fn.set_input(*field_names) + + def set_collator(self, collator: Callable) -> None: + """ + 设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate + + :param collator: 用户自定义的Callable函数 + :return: + """ + self._collate_fn = _MultiCollator(collator) + + def add_collator(self, collator) -> None: + """ + 添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 + + :param collator: + :return: + """ + self._collate_fn.add_collator(collator) + + def get_batch_indices(self) -> List[int]: + """ + 获取当前数据的idx + + :return: + """ + return self.cur_batch_indices + + +def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], + batch_size: int = 1, + shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, + batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + num_workers: int = 0, collate_fn: Optional[Callable] = None, + pin_memory: bool = False, drop_last: bool = False, + timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context=None, generator=None, prefetch_factor: int = 2, + persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, + non_train_batch_size: int = 16, as_numpy: bool = False, + input_fields: Union[List, str] = None)\ + -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: + """ + 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 + + :param input_fields: + :param ds_or_db: dataset或者data_bundle + :param batch_size: 批次大小,当batch_sampler为None生效 + :param shuffle: 是否打乱数据集 + :param sampler: sampler实例化对象 + :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 + :param num_workers: 进程的数量,当num_worker=0时不开启多进程 + :param collate_fn: 对取得到的数据进行打包的callable函数 + :param pin_memory: + :param drop_last: 是否去掉最后一个不符合batch_size的数据 + :param timeout: + :param worker_init_fn: + :param multiprocessing_context: + :param generator: + :param prefetch_factor: + :param persistent_workers: + :param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler + :param non_train_batch_size: + :param as_numpy: 返回数据是否设置为numpy类型,否则根据情况设置为 torch.tensor 类型。 + """ + # TODO dict, sequence情况下需要提供 + if isinstance(input_fields, str): + input_fields = [input_fields] + + if isinstance(ds_or_db, DataSet): + dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + dl.set_input(*input_fields) + return dl + + elif isinstance(ds_or_db, DataBundle): + dl_bundle = {} + for name, ds in ds_or_db.iter_datasets(): + if 'train' in name: + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + else: + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, + shuffle=shuffle, 
sampler=non_train_sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + dl_bundle[name].set_input(*input_fields) + return dl_bundle + + elif isinstance(ds_or_db, Sequence): + dl_bundle = [] + for idx, ds in enumerate(ds_or_db): + if idx == 0: + dl_bundle.append( + TorchDataLoader(dataset=ds, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + ) + else: + dl_bundle.append( + TorchDataLoader(dataset=ds, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + ) + for dl in dl_bundle: + dl.set_input(*input_fields) + return dl_bundle + + elif isinstance(ds_or_db, Mapping): + dl_bundle = {} + for name, ds in ds_or_db.items(): + if 'train' in name: + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + else: + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, + shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + as_numpy=as_numpy) + + dl_bundle[name].set_input(*input_fields) + + return dl_bundle + else: + raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
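For the MixDataLoader added in mix_dataloader.py, a minimal sketch of the intended call pattern as documented in its docstring (illustration only; the toy datasets, split keys and field names are assumed):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders import MixDataLoader

    # toy datasets with illustrative field names
    ds_a = DataSet({'words': [[1, 2], [3, 4, 5]], 'target': [0, 1]})
    ds_b = DataSet({'words': [[6, 7]], 'target': [1]})

    # 'sequential' drains ds_a before ds_b, 'polling' alternates one batch per
    # dataset, 'mix' draws each batch from a randomly chosen dataset.
    dl = MixDataLoader(datasets={'a': ds_a, 'b': ds_b}, mode='mix',
                       batch_size=2, ds_ratio='pad_to_most')

    for batch in dl:
        ...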
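Similarly, a hedged sketch for prepare_paddle_dataloader with a dict of splits (the split names and fields are assumptions; note that in this patch the single-Dataset branch of the function is still a stub, so the dict form is shown):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders import prepare_paddle_dataloader

    # assumed split names and fields, for illustration
    splits = {'train': DataSet({'words': [[1, 2], [3]], 'target': [0, 1]}),
              'dev':   DataSet({'words': [[4, 5, 6]], 'target': [1]})}

    # returns {'train': PaddleDataLoader, 'dev': PaddleDataLoader}; the 'train'
    # split uses train_batch_size, every other split uses non_train_batch_size
    dls = prepare_paddle_dataloader(splits, train_batch_size=32, shuffle=True,
                                    non_train_batch_size=16,
                                    input_fields=['words', 'target'])
    train_dl, dev_dl = dls['train'], dls['dev']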
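And for prepare_torch_dataloader, the most common single-DataSet path (again a sketch under the same assumed toy data; a DataBundle, sequence or mapping of datasets would instead return a dict or list of TorchDataLoader, as handled above):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders import prepare_torch_dataloader

    # toy dataset; field names are illustrative only
    ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'target': [0, 1]})

    dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True,
                                  input_fields=['words', 'target'])
    dl.set_pad_val('words', val=0)    # effective only while an AutoCollator is present

    for batch in dl:
        idx = dl.get_batch_indices()  # indices of the samples in this batch
        ...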