@@ -24,7 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||
class _JittorDataset(Dataset): | |||
""" | |||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | |||
""" | |||
def __init__(self, dataset) -> None: | |||
@@ -37,7 +37,7 @@ class _JittorDataset(Dataset): | |||
item = item.tolist() | |||
return (item, self.dataset[item]) | |||
class JittorDataLoader: | |||
""" | |||
提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, | |||
@@ -2,13 +2,14 @@ __all__ = [ | |||
'MixDataLoader' | |||
] | |||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence | |||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet, Instance | |||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.collators import Collator | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
@@ -18,12 +19,13 @@ else: | |||
class _MixDataset: | |||
""" | |||
将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx | |||
将所有数据集当成一个混合大数据集来对待, 在 __getitem__() 能根据输入的 idx 来判断属于哪个小数据并返回其 ds_index | |||
""" | |||
def __init__(self, datasets: list = None) -> None: | |||
""" | |||
:param datasets: 数据集的列表 | |||
:param datasets: 实现了 __getitem__() 和 __len__() 的对象的序列 | |||
""" | |||
self.datasets = datasets | |||
# 记录每个数据集的长度索引, 以便根据idx定位数据集的位置 | |||
@@ -35,7 +37,7 @@ class _MixDataset: | |||
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | |||
""" | |||
根据index索引获取数据 | |||
根据index索引获取数据, 能够跟 idx 的范围定位属于哪个小数据并返回 | |||
:param idx: 整数类型的index或者列表 | |||
:return: | |||
@@ -69,8 +71,9 @@ class _MixCollateFn: | |||
存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 | |||
""" | |||
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, | |||
auto_collators: Optional[List[Callable]] = None) -> None: | |||
def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None: | |||
if isinstance(collate_fns, Sequence): | |||
self.collate_fns = lambda idx, lst: collate_fns[idx](lst) | |||
elif callable(collate_fns): | |||
@@ -78,96 +81,124 @@ class _MixCollateFn: | |||
else: | |||
self.collate_fns = lambda idx, lst: lst | |||
self.collate_fns = collate_fns | |||
self.auto_collators = auto_collators | |||
def __call__(self, ins_list: List) -> Dict: | |||
""" | |||
调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 | |||
:param ins_list: | |||
:return: | |||
""" | |||
_ins_list, _ds_index = [], 0 | |||
for ins, _ds_index in ins_list: | |||
_ins_list.append(ins) | |||
# auto_collate先处理 | |||
if self.auto_collators is not None: | |||
_ins_list = self.auto_collators[_ds_index](_ins_list) | |||
_ins_list = self.collate_fns(_ds_index, _ins_list) | |||
return _ins_list | |||
class MixDataLoader(DataLoader): | |||
""" | |||
针对一下三种情况提供的MixDataLoader: | |||
1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。 | |||
2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。 | |||
3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset | |||
采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。 | |||
针对一下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
接一个的 sample 完所有数据。 | |||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
sampler, drop_last, ds_ratio 均无效。 | |||
""" | |||
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential', | |||
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None, | |||
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None, | |||
def __init__(self, datasets: Dict = None, mode: Union[str, "Sampler"] = 'sequential', | |||
collate_fn: Union[str, Callable, Dict[str, Callable]] = 'auto', | |||
sampler: Union[Dict[str, "Sampler"], str, None] = None, | |||
num_workers: int = 0, batch_size: int = 16, drop_last=False, | |||
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None, | |||
pin_memory: bool = True) -> None: | |||
ds_ratio: Union[None, str, Dict[str, float]] = None, | |||
pin_memory: bool = False) -> None: | |||
""" | |||
:param datasets: dataset的列表 | |||
:param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况, | |||
当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代 | |||
对象,每次迭代返回的是List[int] | |||
:param collate_fn: 对取得到的数据进行打包的callable函数, | |||
当其为callable类型时候,所有数据集采样的数据都会经过这个函数; | |||
当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数, | |||
其对应关系与datasets位置匹配; | |||
当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。 | |||
:param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象 | |||
sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样; | |||
sampler为List[Sampler],则datasets也为List,且一一对应 | |||
sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。 | |||
:param num_workers: 进程的数量,当num_workers=0时不开启多进程 | |||
:param batch_size: 批次大小, datasets的所有数据集batch_size一致 | |||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||
:param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充 | |||
当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度 | |||
当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止 | |||
当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数, | |||
其大于0,可以超过1,将数据集重采样翻倍即可 | |||
当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。 | |||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
接一个的 sample 完所有数据。 | |||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
sampler, drop_last, ds_ratio 均无效。 | |||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class: `~fastNLP.core.collators.Collator` 作为其默认值, | |||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||
* collate_fn 为 ``Callable`` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||
``[None, str, Dict[str, "Sampler"]]``: | |||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler] `` 为 datasets 字典的每个 dataset 初始化勒一个 Sampler。 | |||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||
到最短长度``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度``mix_len``。 | |||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||
到最大长度``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 | |||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||
""" | |||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||
if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \ | |||
isinstance(sampler, Dict): | |||
raise ValueError(f"") | |||
if isinstance(collate_fn, list): | |||
if len(collate_fn) != len(datasets): | |||
raise ValueError("the length of collate_fn != datasets!!") | |||
if isinstance(sampler, list): | |||
if len(sampler) != len(datasets): | |||
raise ValueError("the length of sampler != datasets!!") | |||
# Dict类型转化为List,以便于_MixCollateFn处理 | |||
# sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||
if isinstance(sampler, Dict): | |||
for key in datasets.keys(): | |||
if not sampler[key]: | |||
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") | |||
# collate_fn 为 dict,则判断是否与 datasets 的 key 相同 | |||
if isinstance(collate_fn, Dict): | |||
if mode == 'mix': | |||
raise ValueError(f"mode: {mode} do not support collate_fn is Dict, please use callate_fn=Callable or 'auto'") | |||
for key in datasets.keys(): | |||
if not collate_fn[key]: | |||
raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!") | |||
if isinstance(collate_fn, str) and collate_fn == 'auto': | |||
date_type = None | |||
for idx, ds in enumerate(datasets.values()): | |||
if idx == 0: | |||
date_type = type(ds[0]) | |||
if type(ds[0]) != date_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)): | |||
raise ValueError(f"when you use callate_fn={collate_fn}, all dataset must be list or dict。" | |||
f"But dataset {idx - 1} data type is {date_type}, dataset {idx} data type is {type(ds[0])}") | |||
collate_fn = Collator(backend='torch') | |||
# Dict 类型的 collate_fn 转化为 List,以便于 _MixCollateFn 里面根据 idx 定位 dataset | |||
if isinstance(collate_fn, Dict): | |||
collate_fn = [fn for _, fn in collate_fn.items()] | |||
# 由于datasets可能是FastNLP类型的dataset或者是交杂的, 故需要检测 | |||
if isinstance(datasets, Dict): | |||
dataset = [ds for _, ds in datasets.items()] | |||
else: | |||
dataset = datasets | |||
auto_collators = [] | |||
for per_ds in dataset: | |||
if isinstance(per_ds, DataSet): | |||
auto_collators.append(per_ds.get_collator()) | |||
else: | |||
# 如果没有对应的collator就设置一个不做任何操作的collator | |||
auto_collators.append(lambda x: x) | |||
# List类型的collate_fn只有两种情况,需要对其进行包裹 | |||
collate_fn = _MixCollateFn(collate_fn, auto_collators) | |||
dataset = [ds for _, ds in datasets.items()] | |||
# 对 collate_fn 进行包裹, 统一处理 collate_fn 不同情况下使用的问题 | |||
collate_fn = _MixCollateFn(collate_fn) | |||
if mode == 'sequential': | |||
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, | |||
drop_last=drop_last, ds_ratio=ds_ratio) | |||
@@ -21,9 +21,9 @@ class MixSampler: | |||
mix_sampler的基类 | |||
""" | |||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None, | |||
ds_ratio: Union[str, List[float], Dict[str, float]] = None, | |||
def __init__(self, dataset: Dict, batch_size: int = None, | |||
sampler: Union[Dict[str, "Sampler"], None, str] = None, | |||
ds_ratio: Union[str, Dict[str, float]] = None, | |||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||
""" | |||
@@ -32,9 +32,12 @@ class MixSampler: | |||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||
""" | |||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||
if isinstance(dataset, Dict) and isinstance(sampler, List): | |||
raise ValueError(f"{sampler} must be dict") | |||
# sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||
if isinstance(sampler, Dict): | |||
for key in dataset.keys(): | |||
if not sampler[key]: | |||
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") | |||
if batch_size <= 0: | |||
raise ValueError("batch_size should be a positive integer value, " | |||
"but got batch_size={}".format(batch_size)) | |||
@@ -46,15 +49,7 @@ class MixSampler: | |||
raise ValueError("if rank>=0 and word_size>=0, sampler must be str") | |||
if sampler is None and (word_size < 0 or rank < 0): | |||
if isinstance(dataset, List): | |||
self.sampler = [SequentialSampler(ds) for ds in dataset] | |||
elif isinstance(dataset, Dict): | |||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||
elif isinstance(sampler, List): | |||
if len(sampler) != len(dataset): | |||
raise ValueError("the length of sampler != the length of sampler") | |||
self.sampler = sampler | |||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||
elif isinstance(sampler, Dict): | |||
self.sampler = sampler | |||
@@ -68,26 +63,7 @@ class MixSampler: | |||
# 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len | |||
sampler_lens, total_lens, sampler_index = [], 0, [] | |||
if isinstance(self.sampler, List): | |||
if ds_ratio is None: | |||
sampler_lens = [len(spl) for spl in self.sampler] | |||
elif ds_ratio == 'pad_to_most': | |||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif ds_ratio == 'truncate_to_least': | |||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif isinstance(ds_ratio, List): | |||
if not all(item >= 0 for item in ds_ratio): | |||
raise ValueError("batch_size should be a positive integer value, " | |||
"but got ds_ratio={}".format(ds_ratio)) | |||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)] | |||
else: | |||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
total_lens = sum(sampler_lens) | |||
elif isinstance(self.sampler, Dict): | |||
if isinstance(self.sampler, Dict): | |||
if ds_ratio is None: | |||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
@@ -100,7 +76,7 @@ class MixSampler: | |||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||
elif isinstance(ds_ratio, Dict): | |||
if not all(item >= 0 for item in ds_ratio): | |||
if not all([item >= 0 for item in ds_ratio.values()]): | |||
raise ValueError("batch_size should be a positive integer value, " | |||
"but got ds_ratio={}".format(ds_ratio)) | |||
sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()] | |||
@@ -108,7 +84,7 @@ class MixSampler: | |||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
total_lens = sum(sampler_lens) | |||
# sampler为str时候,初始化下移到iter方法中 | |||
# sampler 为 str 时候,初始化下移到 iter 方法中 | |||
if len(sampler_lens) > 0: | |||
sampler_index = [sampler_lens[0]] | |||
for idx in sampler_lens[1:]: | |||
@@ -160,75 +136,37 @@ class DopedSampler(MixSampler): | |||
""" | |||
定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 | |||
""" | |||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None, | |||
ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, | |||
def __init__(self, dataset: Dict, batch_size: int = None, | |||
sampler: Union[Dict[str, "Sampler"], str] = None, | |||
ds_ratio: Union[str, None, Dict[str, float]] = None, | |||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||
super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size, | |||
sampler=sampler, ds_ratio=ds_ratio, | |||
drop_last=drop_last, rank=rank, word_size=word_size) | |||
def __iter__(self) -> List[int]: | |||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||
# sampler 为 str, 此时为单机多卡或者单机,可以实现 rand 随机化 | |||
if isinstance(self.sampler, str): | |||
if self.sampler == 'seq': | |||
if isinstance(self.datasets, List): | |||
self.sampler = [] | |||
for per_ds in self.datasets: | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||
else: | |||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||
elif isinstance(self.datasets, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
elif self.sampler == 'rand': | |||
if isinstance(self.datasets, List): | |||
self.sampler = [] | |||
for per_ds in self.datasets: | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||
else: | |||
self.sampler.append(InnerSampler(indices)) | |||
elif isinstance(self.datasets, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(indices) | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(indices) | |||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||
if isinstance(self.sampler, List): | |||
if self.ds_ratio is None: | |||
sampler_lens = [len(spl) for spl in self.sampler] | |||
elif self.ds_ratio == 'pad_to_most': | |||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif self.ds_ratio == 'truncate_to_least': | |||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif isinstance(self.ds_ratio, List): | |||
if not all(item >= 0 for item in self.ds_ratio): | |||
raise ValueError("batch_size should be a positive integer value, " | |||
"but got ds_ratio={}".format(self.ds_ratio)) | |||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||
else: | |||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
total_lens = sum(sampler_lens) | |||
elif isinstance(self.sampler, Dict): | |||
if isinstance(self.sampler, Dict): | |||
if self.ds_ratio is None: | |||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
@@ -257,11 +195,11 @@ class DopedSampler(MixSampler): | |||
sampler_index.append(temp + idx) | |||
self.num_samplers = sampler_index | |||
self.len_samplers = total_lens | |||
# 每个batch的数据, 总的数据量total_index, 每个数据集的samplers | |||
# 每个 batch 的数据, 总的数据量 total_index , 每个数据集的 samplers | |||
batch_idx, samplers = [], [] | |||
# 如果单机则用所有数据,否则采用多卡 | |||
if self.rank < 0 or self.word_size < 0: | |||
# 根据sampler长度判断是否使用unsigned int 或者unsigned long | |||
# 根据 sampler 长度判断是否使用 unsigned int 或者 unsigned long | |||
if self.len_samplers > 42e8: | |||
total_index = array.array('L', list(range(self.len_samplers))) | |||
else: | |||
@@ -274,15 +212,17 @@ class DopedSampler(MixSampler): | |||
else: | |||
total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size]) | |||
start_idx = 0 | |||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||
end_idx = len(spl) | |||
samplers.append((iter(spl), name, start_idx)) | |||
start_idx += end_idx | |||
# 根据sampler的类型取出每个数据集的sampler | |||
if isinstance(self.sampler, List): | |||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||
samplers = [(iter(spl), idx, base_index) | |||
for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))] | |||
else: | |||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
for idx, (name, spl) in enumerate(self.sampler.items())] | |||
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
# samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
# for idx, (name, spl) in enumerate(self.sampler.items())] | |||
# 生成随机数 | |||
np.random.seed(self.epoch) | |||
np.random.shuffle(total_index) | |||
@@ -295,7 +235,7 @@ class DopedSampler(MixSampler): | |||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | |||
spl = iter(self.sampler[name]) | |||
batch_idx.append(next(spl) + base_index) | |||
samplers[name] = (spl, name, base_index) | |||
samplers[ds_index] = (spl, name, base_index) | |||
if len(batch_idx) == self.batch_size: | |||
yield batch_idx | |||
batch_idx = [] | |||
@@ -343,63 +283,26 @@ class MixSequentialSampler(MixSampler): | |||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||
if isinstance(self.sampler, str): | |||
if self.sampler == 'seq': | |||
if isinstance(self.datasets, List): | |||
self.sampler = [] | |||
for per_ds in self.datasets: | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||
else: | |||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||
elif isinstance(self.datasets, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
elif self.sampler == 'rand': | |||
if isinstance(self.datasets, List): | |||
self.sampler = [] | |||
for per_ds in self.datasets: | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||
else: | |||
self.sampler.append(InnerSampler(indices)) | |||
elif isinstance(self.datasets, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(indices) | |||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||
if isinstance(self.sampler, List): | |||
if self.ds_ratio is None: | |||
sampler_lens = [len(spl) for spl in self.sampler] | |||
elif self.ds_ratio == 'pad_to_most': | |||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif self.ds_ratio == 'truncate_to_least': | |||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif isinstance(self.ds_ratio, List): | |||
if not all(item >= 0 for item in self.ds_ratio): | |||
raise ValueError("batch_size should be a positive integer value, " | |||
"but got ds_ratio={}".format(self.ds_ratio)) | |||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||
else: | |||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
total_lens = sum(sampler_lens) | |||
elif isinstance(self.sampler, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(indices) | |||
# 根据给定的 ds_ratio 算真正需要处理数据集 | |||
if isinstance(self.sampler, Dict): | |||
if self.ds_ratio is None: | |||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
@@ -430,21 +333,20 @@ class MixSequentialSampler(MixSampler): | |||
self.len_samplers = total_lens | |||
batch_idx, total_index, samplers = [], list(range(self.len_samplers)), [] | |||
if isinstance(self.sampler, List): | |||
if self.word_size > 0 and self.rank >= 0: | |||
sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1] | |||
else: | |||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||
samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in | |||
enumerate(zip(self.sampler, sampler_base_index))] | |||
else: | |||
if self.word_size > 0 and self.rank >= 0: | |||
sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||
else: | |||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
for idx, (name, spl) in enumerate(self.sampler.items())] | |||
start_idx = 0 | |||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||
end_idx = len(spl) | |||
samplers.append((iter(spl), name, start_idx)) | |||
start_idx += end_idx | |||
# if self.word_size > 0 and self.rank >= 0: | |||
# sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||
# else: | |||
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
# | |||
# samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
# for idx, (name, spl) in enumerate(self.sampler.items())] | |||
for idx in total_index: | |||
ds_index = np.searchsorted(self.num_samplers, idx, side='right') | |||
@@ -455,7 +357,7 @@ class MixSequentialSampler(MixSampler): | |||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | |||
spl = iter(self.sampler[name]) | |||
batch_idx.append(next(spl) + base_index) | |||
samplers[name] = (spl, name, base_index) | |||
samplers[ds_index] = (spl, name, base_index) | |||
if len(batch_idx) == self.batch_size: | |||
yield batch_idx | |||
batch_idx = [] | |||
@@ -506,63 +408,26 @@ class PollingSampler(MixSampler): | |||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||
if isinstance(self.sampler, str): | |||
if self.sampler == 'seq': | |||
if isinstance(self.datasets, List): | |||
self.sampler = [] | |||
for per_ds in self.datasets: | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||
else: | |||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||
elif isinstance(self.datasets, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
elif self.sampler == 'rand': | |||
if isinstance(self.datasets, List): | |||
self.sampler = [] | |||
for per_ds in self.datasets: | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||
else: | |||
self.sampler.append(InnerSampler(indices)) | |||
elif isinstance(self.datasets, Dict): | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(indices) | |||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||
if isinstance(self.sampler, List): | |||
if self.ds_ratio is None: | |||
sampler_lens = [len(spl) for spl in self.sampler] | |||
elif self.ds_ratio == 'pad_to_most': | |||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
self.sampler = {} | |||
for name, per_ds in self.datasets.items(): | |||
g = torch.Generator() | |||
g.manual_seed(self.epoch) | |||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
if self.word_size >= 0 and self.rank >= 0: | |||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
else: | |||
self.sampler[name] = InnerSampler(indices) | |||
elif self.ds_ratio == 'truncate_to_least': | |||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
elif isinstance(self.ds_ratio, List): | |||
if not all(item >= 0 for item in self.ds_ratio): | |||
raise ValueError("batch_size should be a positive integer value, " | |||
"but got ds_ratio={}".format(self.ds_ratio)) | |||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||
else: | |||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
total_lens = sum(sampler_lens) | |||
elif isinstance(self.sampler, Dict): | |||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||
if isinstance(self.sampler, Dict): | |||
if self.ds_ratio is None: | |||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
@@ -592,17 +457,15 @@ class PollingSampler(MixSampler): | |||
self.num_samplers = sampler_index | |||
self.len_samplers = total_lens | |||
start_idx, samplers = 0, [] | |||
if isinstance(self.sampler, List): | |||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)): | |||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx)) | |||
start_idx = end_idx | |||
else: | |||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||
end_idx = self.num_samplers[idx] | |||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name)) | |||
start_idx = end_idx | |||
start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0 | |||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||
end_idx = len(spl) | |||
true_end_idx = self.num_samplers[idx] | |||
samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name)) | |||
start_idx += end_idx | |||
true_start_idx = true_end_idx | |||
while True: | |||
# 退出循环 | |||
@@ -0,0 +1,495 @@ | |||
import pytest | |||
from typing import Mapping | |||
from fastNLP.core.dataloaders import MixDataLoader | |||
from fastNLP import DataSet | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch.utils.data import default_collate, SequentialSampler, RandomSampler | |||
d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10}) | |||
d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) | |||
def test_pad_val(tensor, val=0): | |||
if isinstance(tensor, torch.Tensor): | |||
tensor = tensor.tolist() | |||
for item in tensor: | |||
if item[-1] > 0: | |||
continue | |||
elif item[-1] != val: | |||
return False | |||
return True | |||
class TestMixDataLoader: | |||
def test_sequential_init(self): | |||
datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
# drop_last = True, collate_fn = 'auto | |||
dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True) | |||
for idx, batch in enumerate(dl): | |||
if idx == 0: | |||
# d1 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if idx == 1: | |||
# d2 | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if idx > 1: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
# collate_fn = Callable | |||
def collate_batch(batch): | |||
new_batch = {'x': [], 'y': []} | |||
for ins in batch: | |||
new_batch['x'].append(ins['x']) | |||
new_batch['y'].append(ins['y']) | |||
return new_batch | |||
dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True) | |||
for idx, batch in enumerate(dl1): | |||
if idx == 0: | |||
# d1 | |||
assert [1, 2] in batch['x'] | |||
if idx == 1: | |||
# d2 | |||
assert [101, 201] in batch['x'] | |||
if idx > 1: | |||
# d3 | |||
assert [1000, 2000] in batch['x'] | |||
assert 'x' in batch and 'y' in batch | |||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||
'd2': Collator(backend='auto').set_pad("x", -2), | |||
'd3': Collator(backend='auto').set_pad("x", -3)} | |||
dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) | |||
for idx, batch in enumerate(dl2): | |||
if idx == 0: | |||
assert test_pad_val(batch['x'], val=-1) | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if idx == 1: | |||
assert test_pad_val(batch['x'], val=-2) | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if idx > 1: | |||
assert test_pad_val(batch['x'], val=-3) | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
# sampler 为 str | |||
dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True) | |||
dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True) | |||
for idx, batch in enumerate(dl3): | |||
if idx == 0: | |||
# d1 | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if idx == 1: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if idx == 2: | |||
# d3 | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
if idx > 1: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
for idx, batch in enumerate(dl4): | |||
if idx == 0: | |||
# d1 | |||
assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if idx == 1: | |||
# d2 | |||
assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if idx == 2: | |||
# d3 | |||
assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
if idx > 1: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
# sampler 为 Dict | |||
samplers = {'d1': SequentialSampler(d1), | |||
'd2': SequentialSampler(d2), | |||
'd3': RandomSampler(d3)} | |||
dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True) | |||
for idx, batch in enumerate(dl5): | |||
if idx == 0: | |||
# d1 | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if idx == 1: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if idx > 1: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
# ds_ratio 为 'truncate_to_least' | |||
dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) | |||
for idx, batch in enumerate(dl6): | |||
if idx == 0: | |||
# d1 | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if idx == 1: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if idx == 2: | |||
# d3 | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx > 2: | |||
raise ValueError(f"ds_ratio: 'truncate_to_least' error") | |||
# ds_ratio 为 'pad_to_most' | |||
dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True) | |||
for idx, batch in enumerate(dl7): | |||
if idx < 18: | |||
# d1 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if 18 <= idx < 36: | |||
# d2 | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if 36 <= idx < 54: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 54: | |||
raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
# ds_ratio 为 Dict[str, float] | |||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) | |||
for idx, batch in enumerate(dl8): | |||
if idx < 1: | |||
# d1 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
if 1 <= idx < 4: | |||
# d2 | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if 4 <= idx < 41: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 41: | |||
raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) | |||
for idx, batch in enumerate(dl9): | |||
if idx < 1: | |||
# d2 | |||
assert batch['x'].shape == torch.Size([16, 3]) | |||
if 1 <= idx < 19: | |||
# d3 | |||
assert batch['x'].shape == torch.Size([16, 4]) | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 19: | |||
raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
def test_mix(self): | |||
datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) | |||
for idx, batch in enumerate(dl): | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 22: | |||
raise ValueError(f"out of range") | |||
# collate_fn = Callable | |||
def collate_batch(batch): | |||
new_batch = {'x': [], 'y': []} | |||
for ins in batch: | |||
new_batch['x'].append(ins['x']) | |||
new_batch['y'].append(ins['y']) | |||
return new_batch | |||
dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) | |||
for idx, batch in enumerate(dl1): | |||
assert isinstance(batch['x'], list) | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 22: | |||
raise ValueError(f"out of range") | |||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||
'd2': Collator(backend='auto').set_pad("x", -2), | |||
'd3': Collator(backend='auto').set_pad("x", -3)} | |||
with pytest.raises(ValueError): | |||
MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns) | |||
# sampler 为 str | |||
dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) | |||
for idx, batch in enumerate(dl3): | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 22: | |||
raise ValueError(f"out of range") | |||
dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) | |||
for idx, batch in enumerate(dl4): | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 22: | |||
raise ValueError(f"out of range") | |||
# sampler 为 Dict | |||
samplers = {'d1': SequentialSampler(d1), | |||
'd2': SequentialSampler(d2), | |||
'd3': RandomSampler(d3)} | |||
dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) | |||
for idx, batch in enumerate(dl5): | |||
assert test_pad_val(batch['x'], val=0) | |||
if idx >= 22: | |||
raise ValueError(f"out of range") | |||
# ds_ratio 为 'truncate_to_least' | |||
dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least') | |||
d1_len, d2_len, d3_len = 0, 0, 0 | |||
for idx, batch in enumerate(dl6): | |||
for item in batch['y'].tolist(): | |||
if item in [1, 0, 1]: | |||
d1_len += 1 | |||
elif item in [20, 10, 10]: | |||
d2_len += 1 | |||
elif item in [100, 100, 200]: | |||
d3_len += 1 | |||
if idx >= 6: | |||
raise ValueError(f"ds_ratio 为 'truncate_to_least'出错了") | |||
assert d1_len == d2_len == d3_len == 30 | |||
# ds_ratio 为 'pad_to_most' | |||
dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most') | |||
d1_len, d2_len, d3_len = 0, 0, 0 | |||
for idx, batch in enumerate(dl7): | |||
for item in batch['y'].tolist(): | |||
if item in [1, 0, 1]: | |||
d1_len += 1 | |||
elif item in [20, 10, 10]: | |||
d2_len += 1 | |||
elif item in [100, 100, 200]: | |||
d3_len += 1 | |||
if idx >= 57: | |||
raise ValueError(f"ds_ratio 为 'pad_to_most'出错了") | |||
assert d1_len == d2_len == d3_len == 300 | |||
# ds_ratio 为 Dict[str, float] | |||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) | |||
d1_len, d2_len, d3_len = 0, 0, 0 | |||
for idx, batch in enumerate(dl8): | |||
for item in batch['y'].tolist(): | |||
if item in [1, 0, 1]: | |||
d1_len += 1 | |||
elif item in [20, 10, 10]: | |||
d2_len += 1 | |||
elif item in [100, 100, 200]: | |||
d3_len += 1 | |||
if idx >= 44: | |||
raise ValueError(f"ds_ratio 为 'Dict'出错了") | |||
assert d1_len == 30 | |||
assert d2_len == 60 | |||
assert d3_len == 600 | |||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) | |||
d1_len, d2_len, d3_len = 0, 0, 0 | |||
for idx, batch in enumerate(dl9): | |||
for item in batch['y'].tolist(): | |||
if item in [1, 0, 1]: | |||
d1_len += 1 | |||
elif item in [20, 10, 10]: | |||
d2_len += 1 | |||
elif item in [100, 100, 200]: | |||
d3_len += 1 | |||
if idx >= 21: | |||
raise ValueError(f"ds_ratio 为 'Dict'出错了") | |||
def test_polling(self): | |||
datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18) | |||
for idx, batch in enumerate(dl): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or 4 < idx <= 20: | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 20: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
# collate_fn = Callable | |||
def collate_batch(batch): | |||
new_batch = {'x': [], 'y': []} | |||
for ins in batch: | |||
new_batch['x'].append(ins['x']) | |||
new_batch['y'].append(ins['y']) | |||
return new_batch | |||
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18) | |||
for idx, batch in enumerate(dl1): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]] | |||
elif idx == 2 or 4 < idx <= 20: | |||
assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]] | |||
if idx > 20: | |||
raise ValueError(f"out of range") | |||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||
'd2': Collator(backend='auto').set_pad("x", -2), | |||
'd3': Collator(backend='auto').set_pad("x", -3)} | |||
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) | |||
for idx, batch in enumerate(dl1): | |||
if idx == 0 or idx == 3: | |||
assert test_pad_val(batch['x'], val=-1) | |||
assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert test_pad_val(batch['x'], val=-2) | |||
assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or 4 < idx <= 20: | |||
assert test_pad_val(batch['x'], val=-3) | |||
assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 20: | |||
raise ValueError(f"out of range") | |||
# sampler 为 str | |||
dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18) | |||
dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18) | |||
for idx, batch in enumerate(dl2): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or 4 < idx <= 20: | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 20: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
for idx, batch in enumerate(dl3): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or 4 < idx <= 20: | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 20: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
# sampler 为 Dict | |||
samplers = {'d1': SequentialSampler(d1), | |||
'd2': SequentialSampler(d2), | |||
'd3': RandomSampler(d3)} | |||
dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18) | |||
for idx, batch in enumerate(dl4): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or 4 < idx <= 20: | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 20: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
# ds_ratio 为 'truncate_to_least' | |||
dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) | |||
for idx, batch in enumerate(dl5): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or idx == 5: | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 5: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
# ds_ratio 为 'pad_to_most' | |||
dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) | |||
for idx, batch in enumerate(dl6): | |||
if idx % 3 == 0: | |||
# d1 | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx % 3 == 1: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
if idx % 3 == 2: | |||
# d3 | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx >= 51: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
# ds_ratio 为 Dict[str, float] | |||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||
for idx, batch in enumerate(dl7): | |||
if idx == 0 or idx == 3: | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1 or idx == 4 or idx == 6 or idx == 8: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx == 2 or idx == 5 or idx == 7 or idx > 8: | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 39: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) | |||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||
for idx, batch in enumerate(dl8): | |||
if idx == 0: | |||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
assert batch['x'].shape[1] == 4 | |||
elif idx == 1: | |||
# d2 | |||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
assert batch['x'].shape[1] == 3 | |||
elif idx > 1: | |||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
assert batch['x'].shape[1] == 4 | |||
if idx > 18: | |||
raise ValueError(f"out of range") | |||
test_pad_val(batch['x'], val=0) |