@@ -24,7 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||||
class _JittorDataset(Dataset): | class _JittorDataset(Dataset): | ||||
""" | """ | ||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | ||||
""" | """ | ||||
def __init__(self, dataset) -> None: | def __init__(self, dataset) -> None: | ||||
@@ -37,7 +37,7 @@ class _JittorDataset(Dataset): | |||||
item = item.tolist() | item = item.tolist() | ||||
return (item, self.dataset[item]) | return (item, self.dataset[item]) | ||||
class JittorDataLoader: | class JittorDataLoader: | ||||
""" | """ | ||||
提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, | 提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, | ||||
@@ -2,13 +2,14 @@ __all__ = [ | |||||
'MixDataLoader' | 'MixDataLoader' | ||||
] | ] | ||||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence | |||||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataset import DataSet, Instance | from fastNLP.core.dataset import DataSet, Instance | ||||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.collators import Collator | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
from torch.utils.data import DataLoader, Sampler | from torch.utils.data import DataLoader, Sampler | ||||
@@ -18,12 +19,13 @@ else: | |||||
class _MixDataset: | class _MixDataset: | ||||
""" | """ | ||||
将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx | |||||
将所有数据集当成一个混合大数据集来对待, 在 __getitem__() 能根据输入的 idx 来判断属于哪个小数据并返回其 ds_index | |||||
""" | """ | ||||
def __init__(self, datasets: list = None) -> None: | def __init__(self, datasets: list = None) -> None: | ||||
""" | """ | ||||
:param datasets: 数据集的列表 | |||||
:param datasets: 实现了 __getitem__() 和 __len__() 的对象的序列 | |||||
""" | """ | ||||
self.datasets = datasets | self.datasets = datasets | ||||
# 记录每个数据集的长度索引, 以便根据idx定位数据集的位置 | # 记录每个数据集的长度索引, 以便根据idx定位数据集的位置 | ||||
@@ -35,7 +37,7 @@ class _MixDataset: | |||||
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | ||||
""" | """ | ||||
根据index索引获取数据 | |||||
根据index索引获取数据, 能够跟 idx 的范围定位属于哪个小数据并返回 | |||||
:param idx: 整数类型的index或者列表 | :param idx: 整数类型的index或者列表 | ||||
:return: | :return: | ||||
@@ -69,8 +71,9 @@ class _MixCollateFn: | |||||
存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 | 存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 | ||||
""" | """ | ||||
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, | |||||
auto_collators: Optional[List[Callable]] = None) -> None: | |||||
def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None: | |||||
if isinstance(collate_fns, Sequence): | if isinstance(collate_fns, Sequence): | ||||
self.collate_fns = lambda idx, lst: collate_fns[idx](lst) | self.collate_fns = lambda idx, lst: collate_fns[idx](lst) | ||||
elif callable(collate_fns): | elif callable(collate_fns): | ||||
@@ -78,96 +81,124 @@ class _MixCollateFn: | |||||
else: | else: | ||||
self.collate_fns = lambda idx, lst: lst | self.collate_fns = lambda idx, lst: lst | ||||
self.collate_fns = collate_fns | |||||
self.auto_collators = auto_collators | |||||
def __call__(self, ins_list: List) -> Dict: | def __call__(self, ins_list: List) -> Dict: | ||||
""" | """ | ||||
调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 | 调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 | ||||
:param ins_list: | :param ins_list: | ||||
:return: | :return: | ||||
""" | """ | ||||
_ins_list, _ds_index = [], 0 | _ins_list, _ds_index = [], 0 | ||||
for ins, _ds_index in ins_list: | for ins, _ds_index in ins_list: | ||||
_ins_list.append(ins) | _ins_list.append(ins) | ||||
# auto_collate先处理 | |||||
if self.auto_collators is not None: | |||||
_ins_list = self.auto_collators[_ds_index](_ins_list) | |||||
_ins_list = self.collate_fns(_ds_index, _ins_list) | _ins_list = self.collate_fns(_ds_index, _ins_list) | ||||
return _ins_list | return _ins_list | ||||
class MixDataLoader(DataLoader): | class MixDataLoader(DataLoader): | ||||
""" | """ | ||||
针对一下三种情况提供的MixDataLoader: | |||||
1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。 | |||||
2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。 | |||||
3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset | |||||
采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。 | |||||
针对一下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||||
接一个的 sample 完所有数据。 | |||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||||
sampler, drop_last, ds_ratio 均无效。 | |||||
""" | """ | ||||
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential', | |||||
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None, | |||||
def __init__(self, datasets: Dict = None, mode: Union[str, "Sampler"] = 'sequential', | |||||
collate_fn: Union[str, Callable, Dict[str, Callable]] = 'auto', | |||||
sampler: Union[Dict[str, "Sampler"], str, None] = None, | |||||
num_workers: int = 0, batch_size: int = 16, drop_last=False, | num_workers: int = 0, batch_size: int = 16, drop_last=False, | ||||
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None, | |||||
pin_memory: bool = True) -> None: | |||||
ds_ratio: Union[None, str, Dict[str, float]] = None, | |||||
pin_memory: bool = False) -> None: | |||||
""" | """ | ||||
:param datasets: dataset的列表 | |||||
:param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况, | |||||
当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代 | |||||
对象,每次迭代返回的是List[int] | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数, | |||||
当其为callable类型时候,所有数据集采样的数据都会经过这个函数; | |||||
当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数, | |||||
其对应关系与datasets位置匹配; | |||||
当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。 | |||||
:param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象 | |||||
sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样; | |||||
sampler为List[Sampler],则datasets也为List,且一一对应 | |||||
sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。 | |||||
:param num_workers: 进程的数量,当num_workers=0时不开启多进程 | |||||
:param batch_size: 批次大小, datasets的所有数据集batch_size一致 | |||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||||
:param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充 | |||||
当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度 | |||||
当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止 | |||||
当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数, | |||||
其大于0,可以超过1,将数据集重采样翻倍即可 | |||||
当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。 | |||||
:param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||||
:param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||||
* 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||||
接一个的 sample 完所有数据。 | |||||
* 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||||
混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||||
* 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||||
再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||||
* 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||||
且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||||
sampler, drop_last, ds_ratio 均无效。 | |||||
:param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||||
* collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class: `~fastNLP.core.collators.Collator` 作为其默认值, | |||||
需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||||
* collate_fn 为 ``Callable`` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||||
batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
* collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||||
用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||||
:param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||||
``[None, str, Dict[str, "Sampler"]]``: | |||||
* sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||||
* sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||||
作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||||
* sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||||
Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler] `` 为 datasets 字典的每个 dataset 初始化勒一个 Sampler。 | |||||
:param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||||
也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
:param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||||
:param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
:param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||||
* ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||||
* ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||||
到最短长度``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度``mix_len``。 | |||||
* ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||||
到最大长度``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 | |||||
* ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||||
代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||||
""" | """ | ||||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||||
if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \ | |||||
isinstance(sampler, Dict): | |||||
raise ValueError(f"") | |||||
if isinstance(collate_fn, list): | |||||
if len(collate_fn) != len(datasets): | |||||
raise ValueError("the length of collate_fn != datasets!!") | |||||
if isinstance(sampler, list): | |||||
if len(sampler) != len(datasets): | |||||
raise ValueError("the length of sampler != datasets!!") | |||||
# Dict类型转化为List,以便于_MixCollateFn处理 | |||||
# sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||||
if isinstance(sampler, Dict): | |||||
for key in datasets.keys(): | |||||
if not sampler[key]: | |||||
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") | |||||
# collate_fn 为 dict,则判断是否与 datasets 的 key 相同 | |||||
if isinstance(collate_fn, Dict): | |||||
if mode == 'mix': | |||||
raise ValueError(f"mode: {mode} do not support collate_fn is Dict, please use callate_fn=Callable or 'auto'") | |||||
for key in datasets.keys(): | |||||
if not collate_fn[key]: | |||||
raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!") | |||||
if isinstance(collate_fn, str) and collate_fn == 'auto': | |||||
date_type = None | |||||
for idx, ds in enumerate(datasets.values()): | |||||
if idx == 0: | |||||
date_type = type(ds[0]) | |||||
if type(ds[0]) != date_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)): | |||||
raise ValueError(f"when you use callate_fn={collate_fn}, all dataset must be list or dict。" | |||||
f"But dataset {idx - 1} data type is {date_type}, dataset {idx} data type is {type(ds[0])}") | |||||
collate_fn = Collator(backend='torch') | |||||
# Dict 类型的 collate_fn 转化为 List,以便于 _MixCollateFn 里面根据 idx 定位 dataset | |||||
if isinstance(collate_fn, Dict): | if isinstance(collate_fn, Dict): | ||||
collate_fn = [fn for _, fn in collate_fn.items()] | collate_fn = [fn for _, fn in collate_fn.items()] | ||||
# 由于datasets可能是FastNLP类型的dataset或者是交杂的, 故需要检测 | |||||
if isinstance(datasets, Dict): | |||||
dataset = [ds for _, ds in datasets.items()] | |||||
else: | |||||
dataset = datasets | |||||
auto_collators = [] | |||||
for per_ds in dataset: | |||||
if isinstance(per_ds, DataSet): | |||||
auto_collators.append(per_ds.get_collator()) | |||||
else: | |||||
# 如果没有对应的collator就设置一个不做任何操作的collator | |||||
auto_collators.append(lambda x: x) | |||||
# List类型的collate_fn只有两种情况,需要对其进行包裹 | |||||
collate_fn = _MixCollateFn(collate_fn, auto_collators) | |||||
dataset = [ds for _, ds in datasets.items()] | |||||
# 对 collate_fn 进行包裹, 统一处理 collate_fn 不同情况下使用的问题 | |||||
collate_fn = _MixCollateFn(collate_fn) | |||||
if mode == 'sequential': | if mode == 'sequential': | ||||
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, | batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, | ||||
drop_last=drop_last, ds_ratio=ds_ratio) | drop_last=drop_last, ds_ratio=ds_ratio) | ||||
@@ -21,9 +21,9 @@ class MixSampler: | |||||
mix_sampler的基类 | mix_sampler的基类 | ||||
""" | """ | ||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None, | |||||
ds_ratio: Union[str, List[float], Dict[str, float]] = None, | |||||
def __init__(self, dataset: Dict, batch_size: int = None, | |||||
sampler: Union[Dict[str, "Sampler"], None, str] = None, | |||||
ds_ratio: Union[str, Dict[str, float]] = None, | |||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | ||||
""" | """ | ||||
@@ -32,9 +32,12 @@ class MixSampler: | |||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | ||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | ||||
""" | """ | ||||
# 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||||
if isinstance(dataset, Dict) and isinstance(sampler, List): | |||||
raise ValueError(f"{sampler} must be dict") | |||||
# sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||||
if isinstance(sampler, Dict): | |||||
for key in dataset.keys(): | |||||
if not sampler[key]: | |||||
raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") | |||||
if batch_size <= 0: | if batch_size <= 0: | ||||
raise ValueError("batch_size should be a positive integer value, " | raise ValueError("batch_size should be a positive integer value, " | ||||
"but got batch_size={}".format(batch_size)) | "but got batch_size={}".format(batch_size)) | ||||
@@ -46,15 +49,7 @@ class MixSampler: | |||||
raise ValueError("if rank>=0 and word_size>=0, sampler must be str") | raise ValueError("if rank>=0 and word_size>=0, sampler must be str") | ||||
if sampler is None and (word_size < 0 or rank < 0): | if sampler is None and (word_size < 0 or rank < 0): | ||||
if isinstance(dataset, List): | |||||
self.sampler = [SequentialSampler(ds) for ds in dataset] | |||||
elif isinstance(dataset, Dict): | |||||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||||
elif isinstance(sampler, List): | |||||
if len(sampler) != len(dataset): | |||||
raise ValueError("the length of sampler != the length of sampler") | |||||
self.sampler = sampler | |||||
self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||||
elif isinstance(sampler, Dict): | elif isinstance(sampler, Dict): | ||||
self.sampler = sampler | self.sampler = sampler | ||||
@@ -68,26 +63,7 @@ class MixSampler: | |||||
# 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len | # 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len | ||||
sampler_lens, total_lens, sampler_index = [], 0, [] | sampler_lens, total_lens, sampler_index = [], 0, [] | ||||
if isinstance(self.sampler, List): | |||||
if ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(ds_ratio, List): | |||||
if not all(item >= 0 for item in ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
if isinstance(self.sampler, Dict): | |||||
if ds_ratio is None: | if ds_ratio is None: | ||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | sampler_lens = [len(spl) for _, spl in self.sampler.items()] | ||||
@@ -100,7 +76,7 @@ class MixSampler: | |||||
sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | ||||
elif isinstance(ds_ratio, Dict): | elif isinstance(ds_ratio, Dict): | ||||
if not all(item >= 0 for item in ds_ratio): | |||||
if not all([item >= 0 for item in ds_ratio.values()]): | |||||
raise ValueError("batch_size should be a positive integer value, " | raise ValueError("batch_size should be a positive integer value, " | ||||
"but got ds_ratio={}".format(ds_ratio)) | "but got ds_ratio={}".format(ds_ratio)) | ||||
sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()] | sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()] | ||||
@@ -108,7 +84,7 @@ class MixSampler: | |||||
raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | ||||
total_lens = sum(sampler_lens) | total_lens = sum(sampler_lens) | ||||
# sampler为str时候,初始化下移到iter方法中 | |||||
# sampler 为 str 时候,初始化下移到 iter 方法中 | |||||
if len(sampler_lens) > 0: | if len(sampler_lens) > 0: | ||||
sampler_index = [sampler_lens[0]] | sampler_index = [sampler_lens[0]] | ||||
for idx in sampler_lens[1:]: | for idx in sampler_lens[1:]: | ||||
@@ -160,75 +136,37 @@ class DopedSampler(MixSampler): | |||||
""" | """ | ||||
定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 | 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 | ||||
""" | """ | ||||
def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||||
sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None, | |||||
ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, | |||||
def __init__(self, dataset: Dict, batch_size: int = None, | |||||
sampler: Union[Dict[str, "Sampler"], str] = None, | |||||
ds_ratio: Union[str, None, Dict[str, float]] = None, | |||||
drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | ||||
super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size, | super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size, | ||||
sampler=sampler, ds_ratio=ds_ratio, | sampler=sampler, ds_ratio=ds_ratio, | ||||
drop_last=drop_last, rank=rank, word_size=word_size) | drop_last=drop_last, rank=rank, word_size=word_size) | ||||
def __iter__(self) -> List[int]: | def __iter__(self) -> List[int]: | ||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||||
# sampler 为 str, 此时为单机多卡或者单机,可以实现 rand 随机化 | |||||
if isinstance(self.sampler, str): | if isinstance(self.sampler, str): | ||||
if self.sampler == 'seq': | if self.sampler == 'seq': | ||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
elif self.sampler == 'rand': | elif self.sampler == 'rand': | ||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(indices)) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | # 根据给定的ds_ratio计算真正需要处理数据集 | ||||
if isinstance(self.sampler, List): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(self.ds_ratio, List): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
if isinstance(self.sampler, Dict): | |||||
if self.ds_ratio is None: | if self.ds_ratio is None: | ||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | sampler_lens = [len(spl) for _, spl in self.sampler.items()] | ||||
@@ -257,11 +195,11 @@ class DopedSampler(MixSampler): | |||||
sampler_index.append(temp + idx) | sampler_index.append(temp + idx) | ||||
self.num_samplers = sampler_index | self.num_samplers = sampler_index | ||||
self.len_samplers = total_lens | self.len_samplers = total_lens | ||||
# 每个batch的数据, 总的数据量total_index, 每个数据集的samplers | |||||
# 每个 batch 的数据, 总的数据量 total_index , 每个数据集的 samplers | |||||
batch_idx, samplers = [], [] | batch_idx, samplers = [], [] | ||||
# 如果单机则用所有数据,否则采用多卡 | # 如果单机则用所有数据,否则采用多卡 | ||||
if self.rank < 0 or self.word_size < 0: | if self.rank < 0 or self.word_size < 0: | ||||
# 根据sampler长度判断是否使用unsigned int 或者unsigned long | |||||
# 根据 sampler 长度判断是否使用 unsigned int 或者 unsigned long | |||||
if self.len_samplers > 42e8: | if self.len_samplers > 42e8: | ||||
total_index = array.array('L', list(range(self.len_samplers))) | total_index = array.array('L', list(range(self.len_samplers))) | ||||
else: | else: | ||||
@@ -274,15 +212,17 @@ class DopedSampler(MixSampler): | |||||
else: | else: | ||||
total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size]) | total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size]) | ||||
start_idx = 0 | |||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||||
end_idx = len(spl) | |||||
samplers.append((iter(spl), name, start_idx)) | |||||
start_idx += end_idx | |||||
# 根据sampler的类型取出每个数据集的sampler | # 根据sampler的类型取出每个数据集的sampler | ||||
if isinstance(self.sampler, List): | |||||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||||
samplers = [(iter(spl), idx, base_index) | |||||
for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))] | |||||
else: | |||||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||||
samplers = [(iter(spl), name, sampler_base_index[idx]) | |||||
for idx, (name, spl) in enumerate(self.sampler.items())] | |||||
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||||
# samplers = [(iter(spl), name, sampler_base_index[idx]) | |||||
# for idx, (name, spl) in enumerate(self.sampler.items())] | |||||
# 生成随机数 | # 生成随机数 | ||||
np.random.seed(self.epoch) | np.random.seed(self.epoch) | ||||
np.random.shuffle(total_index) | np.random.shuffle(total_index) | ||||
@@ -295,7 +235,7 @@ class DopedSampler(MixSampler): | |||||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | # 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | ||||
spl = iter(self.sampler[name]) | spl = iter(self.sampler[name]) | ||||
batch_idx.append(next(spl) + base_index) | batch_idx.append(next(spl) + base_index) | ||||
samplers[name] = (spl, name, base_index) | |||||
samplers[ds_index] = (spl, name, base_index) | |||||
if len(batch_idx) == self.batch_size: | if len(batch_idx) == self.batch_size: | ||||
yield batch_idx | yield batch_idx | ||||
batch_idx = [] | batch_idx = [] | ||||
@@ -343,63 +283,26 @@ class MixSequentialSampler(MixSampler): | |||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | ||||
if isinstance(self.sampler, str): | if isinstance(self.sampler, str): | ||||
if self.sampler == 'seq': | if self.sampler == 'seq': | ||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
elif self.sampler == 'rand': | elif self.sampler == 'rand': | ||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(indices)) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||||
if isinstance(self.sampler, List): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(self.ds_ratio, List): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的 ds_ratio 算真正需要处理数据集 | |||||
if isinstance(self.sampler, Dict): | |||||
if self.ds_ratio is None: | if self.ds_ratio is None: | ||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | sampler_lens = [len(spl) for _, spl in self.sampler.items()] | ||||
@@ -430,21 +333,20 @@ class MixSequentialSampler(MixSampler): | |||||
self.len_samplers = total_lens | self.len_samplers = total_lens | ||||
batch_idx, total_index, samplers = [], list(range(self.len_samplers)), [] | batch_idx, total_index, samplers = [], list(range(self.len_samplers)), [] | ||||
if isinstance(self.sampler, List): | |||||
if self.word_size > 0 and self.rank >= 0: | |||||
sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1] | |||||
else: | |||||
sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||||
samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in | |||||
enumerate(zip(self.sampler, sampler_base_index))] | |||||
else: | |||||
if self.word_size > 0 and self.rank >= 0: | |||||
sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||||
else: | |||||
sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||||
samplers = [(iter(spl), name, sampler_base_index[idx]) | |||||
for idx, (name, spl) in enumerate(self.sampler.items())] | |||||
start_idx = 0 | |||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||||
end_idx = len(spl) | |||||
samplers.append((iter(spl), name, start_idx)) | |||||
start_idx += end_idx | |||||
# if self.word_size > 0 and self.rank >= 0: | |||||
# sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||||
# else: | |||||
# sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||||
# | |||||
# samplers = [(iter(spl), name, sampler_base_index[idx]) | |||||
# for idx, (name, spl) in enumerate(self.sampler.items())] | |||||
for idx in total_index: | for idx in total_index: | ||||
ds_index = np.searchsorted(self.num_samplers, idx, side='right') | ds_index = np.searchsorted(self.num_samplers, idx, side='right') | ||||
@@ -455,7 +357,7 @@ class MixSequentialSampler(MixSampler): | |||||
# 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | # 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | ||||
spl = iter(self.sampler[name]) | spl = iter(self.sampler[name]) | ||||
batch_idx.append(next(spl) + base_index) | batch_idx.append(next(spl) + base_index) | ||||
samplers[name] = (spl, name, base_index) | |||||
samplers[ds_index] = (spl, name, base_index) | |||||
if len(batch_idx) == self.batch_size: | if len(batch_idx) == self.batch_size: | ||||
yield batch_idx | yield batch_idx | ||||
batch_idx = [] | batch_idx = [] | ||||
@@ -506,63 +408,26 @@ class PollingSampler(MixSampler): | |||||
# sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | ||||
if isinstance(self.sampler, str): | if isinstance(self.sampler, str): | ||||
if self.sampler == 'seq': | if self.sampler == 'seq': | ||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||||
elif self.sampler == 'rand': | elif self.sampler == 'rand': | ||||
if isinstance(self.datasets, List): | |||||
self.sampler = [] | |||||
for per_ds in self.datasets: | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||||
else: | |||||
self.sampler.append(InnerSampler(indices)) | |||||
elif isinstance(self.datasets, Dict): | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||||
if isinstance(self.sampler, List): | |||||
if self.ds_ratio is None: | |||||
sampler_lens = [len(spl) for spl in self.sampler] | |||||
elif self.ds_ratio == 'pad_to_most': | |||||
sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
self.sampler = {} | |||||
for name, per_ds in self.datasets.items(): | |||||
g = torch.Generator() | |||||
g.manual_seed(self.epoch) | |||||
indices = torch.randperm(len(per_ds), generator=g).tolist() | |||||
if self.word_size >= 0 and self.rank >= 0: | |||||
self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||||
else: | |||||
self.sampler[name] = InnerSampler(indices) | |||||
elif self.ds_ratio == 'truncate_to_least': | |||||
sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||||
elif isinstance(self.ds_ratio, List): | |||||
if not all(item >= 0 for item in self.ds_ratio): | |||||
raise ValueError("batch_size should be a positive integer value, " | |||||
"but got ds_ratio={}".format(self.ds_ratio)) | |||||
sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||||
else: | |||||
raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||||
total_lens = sum(sampler_lens) | |||||
elif isinstance(self.sampler, Dict): | |||||
# 根据给定的ds_ratio计算真正需要处理数据集 | |||||
if isinstance(self.sampler, Dict): | |||||
if self.ds_ratio is None: | if self.ds_ratio is None: | ||||
sampler_lens = [len(spl) for _, spl in self.sampler.items()] | sampler_lens = [len(spl) for _, spl in self.sampler.items()] | ||||
@@ -592,17 +457,15 @@ class PollingSampler(MixSampler): | |||||
self.num_samplers = sampler_index | self.num_samplers = sampler_index | ||||
self.len_samplers = total_lens | self.len_samplers = total_lens | ||||
start_idx, samplers = 0, [] | |||||
if isinstance(self.sampler, List): | |||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||||
for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)): | |||||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx)) | |||||
start_idx = end_idx | |||||
else: | |||||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||||
end_idx = self.num_samplers[idx] | |||||
samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name)) | |||||
start_idx = end_idx | |||||
start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0 | |||||
# (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||||
for idx, (name, spl) in enumerate(self.sampler.items()): | |||||
end_idx = len(spl) | |||||
true_end_idx = self.num_samplers[idx] | |||||
samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name)) | |||||
start_idx += end_idx | |||||
true_start_idx = true_end_idx | |||||
while True: | while True: | ||||
# 退出循环 | # 退出循环 | ||||
@@ -0,0 +1,495 @@ | |||||
import pytest | |||||
from typing import Mapping | |||||
from fastNLP.core.dataloaders import MixDataLoader | |||||
from fastNLP import DataSet | |||||
from fastNLP.core.collators import Collator | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
import torch | |||||
from torch.utils.data import default_collate, SequentialSampler, RandomSampler | |||||
d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||||
d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10}) | |||||
d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) | |||||
def test_pad_val(tensor, val=0): | |||||
if isinstance(tensor, torch.Tensor): | |||||
tensor = tensor.tolist() | |||||
for item in tensor: | |||||
if item[-1] > 0: | |||||
continue | |||||
elif item[-1] != val: | |||||
return False | |||||
return True | |||||
class TestMixDataLoader: | |||||
def test_sequential_init(self): | |||||
datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||||
# drop_last = True, collate_fn = 'auto | |||||
dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True) | |||||
for idx, batch in enumerate(dl): | |||||
if idx == 0: | |||||
# d1 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if idx == 1: | |||||
# d2 | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if idx > 1: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
# collate_fn = Callable | |||||
def collate_batch(batch): | |||||
new_batch = {'x': [], 'y': []} | |||||
for ins in batch: | |||||
new_batch['x'].append(ins['x']) | |||||
new_batch['y'].append(ins['y']) | |||||
return new_batch | |||||
dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True) | |||||
for idx, batch in enumerate(dl1): | |||||
if idx == 0: | |||||
# d1 | |||||
assert [1, 2] in batch['x'] | |||||
if idx == 1: | |||||
# d2 | |||||
assert [101, 201] in batch['x'] | |||||
if idx > 1: | |||||
# d3 | |||||
assert [1000, 2000] in batch['x'] | |||||
assert 'x' in batch and 'y' in batch | |||||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||||
'd2': Collator(backend='auto').set_pad("x", -2), | |||||
'd3': Collator(backend='auto').set_pad("x", -3)} | |||||
dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) | |||||
for idx, batch in enumerate(dl2): | |||||
if idx == 0: | |||||
assert test_pad_val(batch['x'], val=-1) | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if idx == 1: | |||||
assert test_pad_val(batch['x'], val=-2) | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if idx > 1: | |||||
assert test_pad_val(batch['x'], val=-3) | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
# sampler 为 str | |||||
dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True) | |||||
dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True) | |||||
for idx, batch in enumerate(dl3): | |||||
if idx == 0: | |||||
# d1 | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if idx == 1: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if idx == 2: | |||||
# d3 | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
if idx > 1: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
for idx, batch in enumerate(dl4): | |||||
if idx == 0: | |||||
# d1 | |||||
assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if idx == 1: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if idx == 2: | |||||
# d3 | |||||
assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
if idx > 1: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
# sampler 为 Dict | |||||
samplers = {'d1': SequentialSampler(d1), | |||||
'd2': SequentialSampler(d2), | |||||
'd3': RandomSampler(d3)} | |||||
dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True) | |||||
for idx, batch in enumerate(dl5): | |||||
if idx == 0: | |||||
# d1 | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if idx == 1: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if idx > 1: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
# ds_ratio 为 'truncate_to_least' | |||||
dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) | |||||
for idx, batch in enumerate(dl6): | |||||
if idx == 0: | |||||
# d1 | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if idx == 1: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if idx == 2: | |||||
# d3 | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx > 2: | |||||
raise ValueError(f"ds_ratio: 'truncate_to_least' error") | |||||
# ds_ratio 为 'pad_to_most' | |||||
dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True) | |||||
for idx, batch in enumerate(dl7): | |||||
if idx < 18: | |||||
# d1 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if 18 <= idx < 36: | |||||
# d2 | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if 36 <= idx < 54: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 54: | |||||
raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||||
# ds_ratio 为 Dict[str, float] | |||||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||||
dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) | |||||
for idx, batch in enumerate(dl8): | |||||
if idx < 1: | |||||
# d1 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
if 1 <= idx < 4: | |||||
# d2 | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if 4 <= idx < 41: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 41: | |||||
raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||||
dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) | |||||
for idx, batch in enumerate(dl9): | |||||
if idx < 1: | |||||
# d2 | |||||
assert batch['x'].shape == torch.Size([16, 3]) | |||||
if 1 <= idx < 19: | |||||
# d3 | |||||
assert batch['x'].shape == torch.Size([16, 4]) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 19: | |||||
raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||||
def test_mix(self): | |||||
datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||||
dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) | |||||
for idx, batch in enumerate(dl): | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 22: | |||||
raise ValueError(f"out of range") | |||||
# collate_fn = Callable | |||||
def collate_batch(batch): | |||||
new_batch = {'x': [], 'y': []} | |||||
for ins in batch: | |||||
new_batch['x'].append(ins['x']) | |||||
new_batch['y'].append(ins['y']) | |||||
return new_batch | |||||
dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) | |||||
for idx, batch in enumerate(dl1): | |||||
assert isinstance(batch['x'], list) | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 22: | |||||
raise ValueError(f"out of range") | |||||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||||
'd2': Collator(backend='auto').set_pad("x", -2), | |||||
'd3': Collator(backend='auto').set_pad("x", -3)} | |||||
with pytest.raises(ValueError): | |||||
MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns) | |||||
# sampler 为 str | |||||
dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) | |||||
for idx, batch in enumerate(dl3): | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 22: | |||||
raise ValueError(f"out of range") | |||||
dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) | |||||
for idx, batch in enumerate(dl4): | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 22: | |||||
raise ValueError(f"out of range") | |||||
# sampler 为 Dict | |||||
samplers = {'d1': SequentialSampler(d1), | |||||
'd2': SequentialSampler(d2), | |||||
'd3': RandomSampler(d3)} | |||||
dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) | |||||
for idx, batch in enumerate(dl5): | |||||
assert test_pad_val(batch['x'], val=0) | |||||
if idx >= 22: | |||||
raise ValueError(f"out of range") | |||||
# ds_ratio 为 'truncate_to_least' | |||||
dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least') | |||||
d1_len, d2_len, d3_len = 0, 0, 0 | |||||
for idx, batch in enumerate(dl6): | |||||
for item in batch['y'].tolist(): | |||||
if item in [1, 0, 1]: | |||||
d1_len += 1 | |||||
elif item in [20, 10, 10]: | |||||
d2_len += 1 | |||||
elif item in [100, 100, 200]: | |||||
d3_len += 1 | |||||
if idx >= 6: | |||||
raise ValueError(f"ds_ratio 为 'truncate_to_least'出错了") | |||||
assert d1_len == d2_len == d3_len == 30 | |||||
# ds_ratio 为 'pad_to_most' | |||||
dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most') | |||||
d1_len, d2_len, d3_len = 0, 0, 0 | |||||
for idx, batch in enumerate(dl7): | |||||
for item in batch['y'].tolist(): | |||||
if item in [1, 0, 1]: | |||||
d1_len += 1 | |||||
elif item in [20, 10, 10]: | |||||
d2_len += 1 | |||||
elif item in [100, 100, 200]: | |||||
d3_len += 1 | |||||
if idx >= 57: | |||||
raise ValueError(f"ds_ratio 为 'pad_to_most'出错了") | |||||
assert d1_len == d2_len == d3_len == 300 | |||||
# ds_ratio 为 Dict[str, float] | |||||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||||
dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) | |||||
d1_len, d2_len, d3_len = 0, 0, 0 | |||||
for idx, batch in enumerate(dl8): | |||||
for item in batch['y'].tolist(): | |||||
if item in [1, 0, 1]: | |||||
d1_len += 1 | |||||
elif item in [20, 10, 10]: | |||||
d2_len += 1 | |||||
elif item in [100, 100, 200]: | |||||
d3_len += 1 | |||||
if idx >= 44: | |||||
raise ValueError(f"ds_ratio 为 'Dict'出错了") | |||||
assert d1_len == 30 | |||||
assert d2_len == 60 | |||||
assert d3_len == 600 | |||||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||||
dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) | |||||
d1_len, d2_len, d3_len = 0, 0, 0 | |||||
for idx, batch in enumerate(dl9): | |||||
for item in batch['y'].tolist(): | |||||
if item in [1, 0, 1]: | |||||
d1_len += 1 | |||||
elif item in [20, 10, 10]: | |||||
d2_len += 1 | |||||
elif item in [100, 100, 200]: | |||||
d3_len += 1 | |||||
if idx >= 21: | |||||
raise ValueError(f"ds_ratio 为 'Dict'出错了") | |||||
def test_polling(self): | |||||
datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||||
dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18) | |||||
for idx, batch in enumerate(dl): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or 4 < idx <= 20: | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 20: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
# collate_fn = Callable | |||||
def collate_batch(batch): | |||||
new_batch = {'x': [], 'y': []} | |||||
for ins in batch: | |||||
new_batch['x'].append(ins['x']) | |||||
new_batch['y'].append(ins['y']) | |||||
return new_batch | |||||
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18) | |||||
for idx, batch in enumerate(dl1): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]] | |||||
elif idx == 2 or 4 < idx <= 20: | |||||
assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]] | |||||
if idx > 20: | |||||
raise ValueError(f"out of range") | |||||
collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||||
'd2': Collator(backend='auto').set_pad("x", -2), | |||||
'd3': Collator(backend='auto').set_pad("x", -3)} | |||||
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) | |||||
for idx, batch in enumerate(dl1): | |||||
if idx == 0 or idx == 3: | |||||
assert test_pad_val(batch['x'], val=-1) | |||||
assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert test_pad_val(batch['x'], val=-2) | |||||
assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or 4 < idx <= 20: | |||||
assert test_pad_val(batch['x'], val=-3) | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 20: | |||||
raise ValueError(f"out of range") | |||||
# sampler 为 str | |||||
dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18) | |||||
dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18) | |||||
for idx, batch in enumerate(dl2): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or 4 < idx <= 20: | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 20: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
for idx, batch in enumerate(dl3): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or 4 < idx <= 20: | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 20: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
# sampler 为 Dict | |||||
samplers = {'d1': SequentialSampler(d1), | |||||
'd2': SequentialSampler(d2), | |||||
'd3': RandomSampler(d3)} | |||||
dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18) | |||||
for idx, batch in enumerate(dl4): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or 4 < idx <= 20: | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 20: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
# ds_ratio 为 'truncate_to_least' | |||||
dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) | |||||
for idx, batch in enumerate(dl5): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or idx == 5: | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 5: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
# ds_ratio 为 'pad_to_most' | |||||
dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) | |||||
for idx, batch in enumerate(dl6): | |||||
if idx % 3 == 0: | |||||
# d1 | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx % 3 == 1: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
if idx % 3 == 2: | |||||
# d3 | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx >= 51: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
# ds_ratio 为 Dict[str, float] | |||||
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||||
dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||||
for idx, batch in enumerate(dl7): | |||||
if idx == 0 or idx == 3: | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1 or idx == 4 or idx == 6 or idx == 8: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx == 2 or idx == 5 or idx == 7 or idx > 8: | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 39: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) | |||||
ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||||
dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||||
for idx, batch in enumerate(dl8): | |||||
if idx == 0: | |||||
assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||||
assert batch['x'].shape[1] == 4 | |||||
elif idx == 1: | |||||
# d2 | |||||
assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||||
assert batch['x'].shape[1] == 3 | |||||
elif idx > 1: | |||||
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||||
assert batch['x'].shape[1] == 4 | |||||
if idx > 18: | |||||
raise ValueError(f"out of range") | |||||
test_pad_val(batch['x'], val=0) |