
Added dataloaders

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
69102a8427
9 changed files with 864 additions and 0 deletions
1. +14 -0   fastNLP/core/dataloaders/__init__.py
2. +7 -0    fastNLP/core/dataloaders/fdataloader.py
3. +7 -0    fastNLP/core/dataloaders/jittor_dataloader/__init__.py
4. +138 -0  fastNLP/core/dataloaders/jittor_dataloader/fdl.py
5. +194 -0  fastNLP/core/dataloaders/mix_dataloader.py
6. +6 -0    fastNLP/core/dataloaders/paddle_dataloader/__init__.py
7. +192 -0  fastNLP/core/dataloaders/paddle_dataloader/fdl.py
8. +6 -0    fastNLP/core/dataloaders/torch_dataloader/__init__.py
9. +300 -0  fastNLP/core/dataloaders/torch_dataloader/fdl.py

+14 -0  fastNLP/core/dataloaders/__init__.py

@@ -0,0 +1,14 @@
__all__ = [
'MixDataLoader',
'TorchDataLoader',
'PaddleDataLoader',
'JittorDataLoader',
'prepare_jittor_dataloader',
'prepare_paddle_dataloader',
'prepare_torch_dataloader'
]

from .mix_dataloader import MixDataLoader
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
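A minimal import sketch (not part of the commit) of how this package is meant to be consumed as a single entry point; which loaders are actually usable depends on the backend (torch/paddle/jittor) installed.

from fastNLP.core.dataloaders import (
    MixDataLoader,
    TorchDataLoader, prepare_torch_dataloader,
    PaddleDataLoader, prepare_paddle_dataloader,
    JittorDataLoader, prepare_jittor_dataloader,
)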

+7 -0  fastNLP/core/dataloaders/fdataloader.py

@@ -0,0 +1,7 @@
__all__ = [
'FDataLoader'
]


class FDataLoader:
pass

+7 -0  fastNLP/core/dataloaders/jittor_dataloader/__init__.py

@@ -0,0 +1,7 @@
__all__ = [
"JittorDataLoader",
'prepare_jittor_dataloader'

]

from .fdl import JittorDataLoader, prepare_jittor_dataloader

+138 -0  fastNLP/core/dataloaders/jittor_dataloader/fdl.py

@@ -0,0 +1,138 @@
__all__ = [
'JittorDataLoader',
'prepare_jittor_dataloader'
]

from typing import Callable, Optional, List

from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
if _NEED_IMPORT_JITTOR:
from jittor.dataset.utils import collate_batch
from jittor.dataset import Dataset
else:
from fastNLP.core.dataset import DataSet as Dataset
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps
from fastNLP.core.collators import AutoCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet


class _JittorDataset(Dataset):
"""
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset使用jittor的dataset
"""

def __init__(self, dataset) -> None:
super(_JittorDataset, self).__init__()
self.dataset = dataset

def __getitem__(self, item):
return (item, self.dataset[item])

def __len__(self) -> int:
return len(self.dataset)

# def __getattr__(self, item):
#     # for methods that jittor's Dataset lacks but the user's dataset implements via __getattribute__, forward the call
#     try:
#         self.dataset.__getattribute__(item)
#     except Exception as e:
#         raise e


class JittorDataLoader:
"""
提供给使用jittor框架的DataLoader函数,提供了auto_collate的功能, 支持实现了__getitem__和__len__的dataset
"""

def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
collate_fn: Callable = None) -> None:
"""

:param dataset: 实现__getitem__和__len__的dataset
:param batch_size: 批次大小
:param shuffle: 是否打乱数据集
:param drop_last: 是否去掉最后一个不符合batch_size的数据
:param num_workers: 进程的数量,当num_workers=0时不开启多进程
:param buffer_size:
:param stop_grad:
:param keep_numpy_array:
:param endless:
:param collate_fn: 对取得到的数据进行打包的callable函数
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型
"""
# TODO support fastNLP DataSet
# TODO verify support for ReplaceSampler (later)
# check whether this is a fastNLP DataSet (to reuse its collator)

if isinstance(dataset, FDataSet):
collator = dataset.get_collator().set_as_numpy(as_numpy=True)
else:
collator = None

self.dataset = _JittorDataset(dataset)

self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
keep_numpy_array=keep_numpy_array, endless=endless)
if isinstance(self.dataset.dataset, Dataset):
self.dataset.dataset.set_attrs(batch_size=1)
# if the user did not provide a collate_fn, fall back to jittor's collate_batch
self.collate_fn = collate_fn
if self.collate_fn is None:
self.collate_fn = collate_batch
self.auto_collator = collator
self.cur_batch_indices = None

def __iter__(self):
# TODO: collate_fn cannot be changed after the first iteration; later assignments have no effect
if self.cur_batch_indices is None:
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn,
self.auto_collator)))
for indices, data in self.dataset.__iter__():
self.cur_batch_indices = indices
yield data

def __len__(self):
if self.dataset.drop_last:
return len(self.dataset) // self.dataset.batch_size
return (len(self.dataset) - 1) // self.dataset.batch_size + 1

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数
当val=None时,意味着给定的field_names都不需要尝试padding

:param field_names:
:param val: padding值,默认为0
:return:
"""
if self.auto_collator is None:
self.auto_collator = AutoCollator(as_numpy=True)
self.auto_collator.set_pad_val(*field_names, val=val)

def set_input(self, *field_names) -> None:
"""
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉

:param field_names:
:return:
"""
if self.auto_collator is None:
self.auto_collator = AutoCollator(as_numpy=True)

self.auto_collator.set_input(*field_names)

def get_batch_indices(self) -> List[int]:
"""
获取当前数据的idx

:return:
"""
return self.cur_batch_indices


def prepare_jittor_dataloader():
...
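A minimal usage sketch for JittorDataLoader (not part of this commit), assuming jittor is installed. MyDataset below is hypothetical; any object with __getitem__ and __len__ should work, and with a fastNLP DataSet the auto collator would be picked up instead.

from fastNLP.core.dataloaders import JittorDataLoader

class MyDataset:
    """Hypothetical map-style dataset used only for illustration."""
    def __init__(self, n=100):
        self.data = [{"x": [i, i + 1], "y": i % 2} for i in range(n)]
    def __getitem__(self, idx):
        return self.data[idx]
    def __len__(self):
        return len(self.data)

dl = JittorDataLoader(MyDataset(), batch_size=16, shuffle=True)
for batch in dl:                       # batches are packed by jittor's collate_batch here
    indices = dl.get_batch_indices()   # sample indices that make up the current batch
    ...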

+194 -0  fastNLP/core/dataloaders/mix_dataloader.py

@@ -0,0 +1,194 @@
__all__ = [
'MixDataLoader'
]

from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence

import numpy as np

from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader


class _MixDataset:
"""
将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx

"""
def __init__(self, datasets: list = None) -> None:
"""

:param datasets: 数据集的列表
"""
self.datasets = datasets
# record cumulative dataset lengths so that an idx can be located in the right dataset
self.lens = []
index = 0
for item in self.datasets:
index += len(item)
self.lens.append(index)

def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]:
"""

:param idx:
:return:
"""
if isinstance(idx, int):
if idx >= self.lens[-1]:
raise ValueError(f"idx: {idx} out of range")
# locate the dataset this idx belongs to and return its index
ds_index = np.searchsorted(self.lens, idx, side='right')
if ds_index > 0:
idx -= self.lens[ds_index - 1]
return self.datasets[ds_index][idx], ds_index
elif isinstance(idx, list):
# a list of indices is expected to come from a single dataset; mixed lists are an error
dataset = DataSet()
ds_index = 0
for i in idx:
assert isinstance(i, int), "Only int index allowed."
instance, ds_index = self[i]
dataset.append(instance)
return dataset, ds_index
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __len__(self) -> int:
return self.lens[-1]


class _MixCollateFn:
"""
存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题

"""
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None,
auto_collators: Optional[List[Callable]] = None) -> None:
if isinstance(collate_fns, Sequence):
self.collate_fns = lambda idx, lst: collate_fns[idx](lst)
elif callable(collate_fns):
self.collate_fns = lambda idx, lst: collate_fns(lst)
else:
self.collate_fns = lambda idx, lst: lst

self.auto_collators = auto_collators

def __call__(self, ins_list: List) -> Dict:
"""
调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种
:param ins_list:
:return:
"""
_ins_list, _ds_index = [], 0
for ins, _ds_index in ins_list:
_ins_list.append(ins)
# apply the auto collator first
if self.auto_collators is not None:
_ins_list = self.auto_collators[_ds_index](_ins_list)
_ins_list = self.collate_fns(_ds_index, _ins_list)
return _ins_list


class MixDataLoader(DataLoader):
"""
针对一下三种情况提供的MixDataLoader:
1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。
2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。
3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset
采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。
"""
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential',
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None,
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None,
num_workers: int = 0, batch_size: int = 16, drop_last=False,
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None,
pin_memory: bool = True) -> None:
"""

:param datasets: dataset的列表
:param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况,
当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代
对象,每次迭代返回的是List[int]
:param collate_fn: 对取得到的数据进行打包的callable函数,
当其为callable类型时候,所有数据集采样的数据都会经过这个函数;
当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数,
其对应关系与datasets位置匹配;
当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。
:param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象
sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样;
sampler为List[Sampler],则datasets也为List,且一一对应
sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。
:param num_workers: 进程的数量,当num_workers=0时不开启多进程
:param batch_size: 批次大小, datasets的所有数据集batch_size一致
:param drop_last: 是否去掉最后一个不符合batch_size的数据
:param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充
当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度
当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止
当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数,
其大于0,可以超过1,将数据集重采样翻倍即可
当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。
"""
# if datasets is a Dict, the other arguments such as collate_fn and sampler must be Dict (or a single Callable) as well
if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \
isinstance(sampler, Dict):
raise ValueError("datasets is not a Dict, but collate_fn/sampler are Dict; their types must match")

if isinstance(collate_fn, list):
if len(collate_fn) != len(datasets):
raise ValueError("the length of collate_fn != datasets!!")

if isinstance(sampler, list):
if len(sampler) != len(datasets):
raise ValueError("the length of sampler != datasets!!")

# convert Dict to List so that _MixCollateFn can handle it
if isinstance(collate_fn, Dict):
collate_fn = [fn for _, fn in collate_fn.items()]

# datasets may contain fastNLP DataSet objects or a mixture of types, so check each one
if isinstance(datasets, Dict):
dataset = [ds for _, ds in datasets.items()]
else:
dataset = datasets
auto_collators = []
for per_ds in dataset:
if isinstance(per_ds, DataSet):
auto_collators.append(per_ds.get_collator())
else:
# if there is no matching collator, use a no-op collator
auto_collators.append(lambda x: x)

# collate_fn is now either a list or a single callable; wrap it together with the auto collators
collate_fn = _MixCollateFn(collate_fn, auto_collators)
if mode == 'sequential':
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler,
drop_last=drop_last, ds_ratio=ds_ratio)
elif mode == 'polling':
batch_sampler = PollingSampler(datasets, batch_size=batch_size, sampler=sampler,
drop_last=drop_last, ds_ratio=ds_ratio)
elif mode == 'mix':
batch_sampler = DopedSampler(datasets, batch_size=batch_size, sampler=sampler,
drop_last=drop_last, ds_ratio=ds_ratio)
elif isinstance(mode, Sampler):
batch_sampler = mode
else:
raise ValueError(f"{mode} must be sequential, polling, mix or batch_sampler")

super(MixDataLoader, self).__init__(
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None, generator=None,
prefetch_factor=2, persistent_workers=False
)

def __iter__(self):
return super().__iter__()
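A hedged usage sketch for MixDataLoader (not part of this commit), assuming torch is installed and that fastNLP's Instance accepts keyword fields as shown; the two toy datasets are illustrative only.

from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.dataloaders import MixDataLoader

def make_ds(n):
    ds = DataSet()
    for i in range(n):
        ds.append(Instance(x=[i, i + 1], y=i % 2))
    return ds

datasets = {"task_a": make_ds(100), "task_b": make_ds(60)}
# 'sequential' exhausts task_a before task_b; 'polling' alternates one batch per dataset;
# 'mix' draws mixed batches across both datasets.
dl = MixDataLoader(datasets=datasets, mode='sequential', batch_size=16, drop_last=False)
for batch in dl:
    ...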

+6 -0  fastNLP/core/dataloaders/paddle_dataloader/__init__.py

@@ -0,0 +1,6 @@
__all__ = [
'prepare_paddle_dataloader',
'PaddleDataLoader'
]

from .fdl import PaddleDataLoader, prepare_paddle_dataloader

+192 -0  fastNLP/core/dataloaders/paddle_dataloader/fdl.py

@@ -0,0 +1,192 @@
__all__ = [
'PaddleDataLoader',
'prepare_paddle_dataloader'
]

from typing import Callable, List, Optional, Union, Dict, Sequence

from fastNLP.envs.imports import _NEED_IMPORT_PADDLE
if _NEED_IMPORT_PADDLE:
from paddle.io import DataLoader, Dataset
from paddle.fluid.dataloader.collate import default_collate_fn
else:
from fastNLP.core.utils.dummy_class import DummyClass as Dataset
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader

from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.core.dataset import DataSet as FDataSet


class _PaddleDataset(Dataset):
"""
对用户传的dataset进行封装,以便Fdataloader能够支持使用自定义的dataset使用paddle的dataloader
"""

def __init__(self, dataset) -> None:
super(_PaddleDataset, self).__init__()
self.dataset = dataset

def __getitem__(self, item):
return (item, self.dataset[item])

def __len__(self) -> int:
return len(self.dataset)

def __getattr__(self, item):
try:
return self.dataset.__getattribute__(item)
except Exception as e:
raise e


class PaddleDataLoader(DataLoader):

def __init__(self, dataset, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Callable = None,
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False) -> None:

if not isinstance(dataset, _PaddleDataset):
dataset = _PaddleDataset(dataset)

super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
return_list=return_list, batch_sampler=batch_sampler,
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
collate_fn=None, num_workers=num_workers,
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
timeout=timeout, worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers)
if isinstance(dataset.dataset, FDataSet):
self._collate_fn = dataset.dataset.get_collator()
self._collate_fn.set_as_numpy(as_numpy=True)
if collate_fn is not None:
self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = _MultiCollator(collate_fn)
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
# if collate_fn is not None:
# _collate_fn.add_collator(collate_fn)
# self._collate_fn = _collate_fn
self.cur_batch_indices = None

def __iter__(self):
# if there is no auto collator and no user collate_fn, fall back to the dataloader's default collate_fn to batch the data
if len(self._collate_fn.get_collators()) == 0:
self._collate_fn.add_collator(default_collate_fn)
# self._collate_fn = default_collate_fn
self.collate_fn = indice_collate_wrapper(self._collate_fn)
for indices, data in super().__iter__():
self.cur_batch_indices = indices
yield data

def __getattr__(self, item):
"""
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法

:param item:
:return:
"""
try:
return self.dataset.__getattr__(item)
except AttributeError as e:
raise e

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数
当val=None时,意味着给定的field_names都不需要尝试padding

:param field_names:
:param val: padding值,默认为0
:return:
"""
for field_name in field_names:
self._collate_fn.set_pad_val(field_name, val=val)

def set_input(self, *field_names) -> None:
"""
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉

:param field_names:
:return:
"""
self._collate_fn.set_input(*field_names)

def set_collator(self, collator: Callable) -> None:
"""
设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate

:param collator: 用户自定义的Callable函数
:return:
"""
self._collate_fn = _MultiCollator(collator)

def add_collator(self, collator) -> None:
"""
添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面

:param collator:
:return:
"""
self._collate_fn.add_collator(collator)

def get_batch_indices(self) -> List[int]:
"""
获取当前数据的idx

:return:
"""
return self.cur_batch_indices


def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return_list: bool = True, batch_sampler=None,
train_batch_size: int = 1, shuffle: bool = False,
drop_last: bool = False, collate_fn: Callable = None,
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False,
non_train_batch_size: int = 16,
input_fields: Union[List[str], str] = None)\
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
if isinstance(input_fields, str):
input_fields = [input_fields]
elif input_fields is None:
input_fields = []

if isinstance(ds_or_db, Dataset):
...
elif isinstance(ds_or_db, Sequence):
ds_seq = []
for ds in ds_or_db:
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
dl.set_input(*input_fields)
ds_seq.append(dl)
return ds_seq

elif isinstance(ds_or_db, Dict):
ds_dict = {}
for name, ds in ds_or_db.items():
if 'train' in name:
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
else:
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle,
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
dl.set_input(*input_fields)
ds_dict[name] = dl
return ds_dict
else:
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")

+6 -0  fastNLP/core/dataloaders/torch_dataloader/__init__.py

@@ -0,0 +1,6 @@
__all__ = [
"TorchDataLoader",
"prepare_torch_dataloader"
]

from .fdl import TorchDataLoader, prepare_torch_dataloader

+300 -0  fastNLP/core/dataloaders/torch_dataloader/fdl.py

@@ -0,0 +1,300 @@
__all__ = [
'TorchDataLoader',
'prepare_torch_dataloader'
]

from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.utils.utils import indice_collate_wrapper
from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
from torch.utils.data._utils.collate import default_collate
else:
from ..fdataloader import FDataLoader as DataLoader


class _FDataSet:
"""
对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset
中调用dataset的方法
"""
def __init__(self, dataset) -> None:
self.dataset = dataset

def __getitem__(self, item: Union[int, list]) -> Tuple:
return (item, self.dataset[item])

def __getattr__(self, item):
try:
return self.dataset.__getattribute__(item)
except AttributeError as e:
raise e

def __len__(self) -> int:
return len(self.dataset)


class TorchDataLoader(DataLoader):
"""
提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过
提供的方法调节设置collate_fn的若干参数。
"""
def __init__(self, dataset, batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, as_numpy: bool = False) -> None:
"""

:param dataset: 实现了__getitem__和__len__的数据容器
:param batch_size: 批次大小,当batch_sampler为None生效
:param shuffle: 是否打乱数据集
:param sampler: sampler实例化对象
:param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据
:param num_workers: 进程的数量,当num_worker=0时不开启多进程
:param collate_fn: 对取得到的数据进行打包的callable函数
:param pin_memory:
:param drop_last: 是否去掉最后一个不符合batch_size的数据
:param timeout:
:param worker_init_fn:
:param multiprocessing_context:
:param generator:
:param prefetch_factor:
:param persistent_workers:
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型
"""
if not isinstance(dataset, _FDataSet):
dataset = _FDataSet(dataset)

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)
if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is being used
self._collate_fn = dataset.dataset.get_collator()
self._collate_fn.set_as_numpy(as_numpy)
if collate_fn is not None and collate_fn is not default_collate:
# avoid adding torch dataloader's default collate again when DDP re-initializes the dataloader
self._collate_fn.add_collator(collate_fn)
else:
self._collate_fn = _MultiCollator(collate_fn)

self.cur_batch_indices = None
self.as_numpy = as_numpy

def __getattr__(self, item):
"""
为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法
:param item:
:return:
"""
try:
return self.dataset.__getattr__(item)
except AttributeError as e:
raise e

def __iter__(self):
# if there is no auto collator and no user collate_fn, fall back to the dataloader's own collate_fn to batch the data
if len(self._collate_fn.get_collators()) == 0:
self._collate_fn.add_collator(self.collate_fn)
self.collate_fn = indice_collate_wrapper(self._collate_fn)
for indices, data in super().__iter__():
self.cur_batch_indices = indices
yield data

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数
当val=None时,意味着给定的field_names都不需要尝试padding
:param field_names:
:param val: padding值,默认为0
:return:
"""
flag = False
for collator in self._collate_fn.get_collators():
if isinstance(collator, AutoCollator):
flag = True
break
if flag is False:
self._collate_fn.add_collator(AutoCollator(self.as_numpy))
for field_name in field_names:
self._collate_fn.set_pad_val(field_name, val=val)

def set_input(self, *field_names) -> None:
"""
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉

:param field_names:
:return:
"""
flag = False
for collator in self._collate_fn.get_collators():
if isinstance(collator, AutoCollator):
flag = True
break
if flag is False:
self._collate_fn.add_collator(AutoCollator(self.as_numpy))
self._collate_fn.set_input(*field_names)

def set_collator(self, collator: Callable) -> None:
"""
设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate

:param collator: 用户自定义的Callable函数
:return:
"""
self._collate_fn = _MultiCollator(collator)

def add_collator(self, collator) -> None:
"""
添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面

:param collator:
:return:
"""
self._collate_fn.add_collator(collator)

def get_batch_indices(self) -> List[int]:
"""
获取当前数据的idx

:return:
"""
return self.cur_batch_indices


def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None,
batch_sampler: Optional["Sampler[Sequence[int]]"] = None,
num_workers: int = 0, collate_fn: Optional[Callable] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None,
non_train_batch_size: int = 16, as_numpy: bool = False,
input_fields: Union[List, str] = None)\
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
"""
传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象

:param input_fields:
:param ds_or_db: dataset或者data_bundle
:param batch_size: 批次大小,当batch_sampler为None生效
:param shuffle: 是否打乱数据集
:param sampler: sampler实例化对象
:param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据
:param num_workers: 进程的数量,当num_worker=0时不开启多进程
:param collate_fn: 对取得到的数据进行打包的callable函数
:param pin_memory:
:param drop_last: 是否去掉最后一个不符合batch_size的数据
:param timeout:
:param worker_init_fn:
:param multiprocessing_context:
:param generator:
:param prefetch_factor:
:param persistent_workers:
:param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler
:param non_train_batch_size:
:param as_numpy: 返回数据是否设置为numpy类型,否则根据情况设置为 torch.tensor 类型。
"""
# TODO: the dict and sequence cases still need dedicated handling
if isinstance(input_fields, str):
input_fields = [input_fields]
elif input_fields is None:
input_fields = []

if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
dl.set_input(*input_fields)
return dl

elif isinstance(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
dl_bundle[name].set_input(*input_fields)
return dl_bundle

elif isinstance(ds_or_db, Sequence):
dl_bundle = []
for idx, ds in enumerate(ds_or_db):
if idx == 0:
dl_bundle.append(
TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
else:
dl_bundle.append(
TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
)
for dl in dl_bundle:
dl.set_input(*input_fields)
return dl_bundle

elif isinstance(ds_or_db, Mapping):
dl_bundle = {}
for name, ds in ds_or_db.items():
if 'train' in name:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)
else:
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler,
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers,
as_numpy=as_numpy)

dl_bundle[name].set_input(*input_fields)

return dl_bundle
else:
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!")
