| @@ -24,7 +24,7 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||
| class _JittorDataset(Dataset): | |||
| """ | |||
| 对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | |||
| """ | |||
| def __init__(self, dataset) -> None: | |||
| @@ -37,7 +37,7 @@ class _JittorDataset(Dataset): | |||
| item = item.tolist() | |||
| return (item, self.dataset[item]) | |||
| class JittorDataLoader: | |||
| """ | |||
| 提供给 ``jittor`` 框架使用的 ``DataLoader`` 函数,``JittorDataLoader`` 提供了 ``Collator`` 来自动检测 dataset 的每个 field 是否可 pad, | |||
| @@ -2,13 +2,14 @@ __all__ = [ | |||
| 'MixDataLoader' | |||
| ] | |||
| from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence | |||
| from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence, Mapping | |||
| import numpy as np | |||
| from fastNLP.core.dataset import DataSet, Instance | |||
| from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| from fastNLP.core.collators import Collator | |||
| if _NEED_IMPORT_TORCH: | |||
| from torch.utils.data import DataLoader, Sampler | |||
| @@ -18,12 +19,13 @@ else: | |||
| class _MixDataset: | |||
| """ | |||
| 将所有数据集当成一个混合大数据集来对待,实现的__getitem__能区别每个数据idx | |||
| 将所有数据集当成一个混合大数据集来对待, 在 __getitem__() 能根据输入的 idx 来判断属于哪个小数据并返回其 ds_index | |||
| """ | |||
| def __init__(self, datasets: list = None) -> None: | |||
| """ | |||
| :param datasets: 数据集的列表 | |||
| :param datasets: 实现了 __getitem__() 和 __len__() 的对象的序列 | |||
| """ | |||
| self.datasets = datasets | |||
| # 记录每个数据集的长度索引, 以便根据idx定位数据集的位置 | |||
| @@ -35,7 +37,7 @@ class _MixDataset: | |||
| def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | |||
| """ | |||
| 根据index索引获取数据 | |||
| 根据index索引获取数据, 能够跟 idx 的范围定位属于哪个小数据并返回 | |||
| :param idx: 整数类型的index或者列表 | |||
| :return: | |||
| @@ -69,8 +71,9 @@ class _MixCollateFn: | |||
| 存在多个auto_collate和多个collate_fn时候,对一个批次数据集应用哪个auto_collate和collate_fn的问题 | |||
| """ | |||
| def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, | |||
| auto_collators: Optional[List[Callable]] = None) -> None: | |||
| def __init__(self, collate_fns: Union[List[Callable], Callable]) -> None: | |||
| if isinstance(collate_fns, Sequence): | |||
| self.collate_fns = lambda idx, lst: collate_fns[idx](lst) | |||
| elif callable(collate_fns): | |||
| @@ -78,96 +81,124 @@ class _MixCollateFn: | |||
| else: | |||
| self.collate_fns = lambda idx, lst: lst | |||
| self.collate_fns = collate_fns | |||
| self.auto_collators = auto_collators | |||
| def __call__(self, ins_list: List) -> Dict: | |||
| """ | |||
| 调用一次该方法,我们将ins_list视为同一个数据集采样出来的,故ds_index只能为一种 | |||
| :param ins_list: | |||
| :return: | |||
| """ | |||
| _ins_list, _ds_index = [], 0 | |||
| for ins, _ds_index in ins_list: | |||
| _ins_list.append(ins) | |||
| # auto_collate先处理 | |||
| if self.auto_collators is not None: | |||
| _ins_list = self.auto_collators[_ds_index](_ins_list) | |||
| _ins_list = self.collate_fns(_ds_index, _ins_list) | |||
| return _ins_list | |||
| class MixDataLoader(DataLoader): | |||
| """ | |||
| 针对一下三种情况提供的MixDataLoader: | |||
| 1. 给定datasets集合或者列表,顺序采样datasets,处理采样完首个dataset后取出第二个dataset,重复上面过程直至datasets取完。 | |||
| 2. 给定datasets集合或者列表,随机采样这个datasets的任意一个数据集组合成一个混合的batch返回给用户,直至datasets所有数据集采样完。 | |||
| 3. 给定datasets集合或者列表,轮流采样datasets:即是循环遍历datasets,每取出一个dataset采样一个batch的数据,然后取出下一个dataset | |||
| 采样一个batch数据,重复上述过程直至某个dataset采样结束或者所有dataset采样结束。 | |||
| 针对一下四种情况提供的 ``MixDataLoader``, 目前只支持 ``torch`` 框架的版本, 其中 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
| * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
| 接一个的 sample 完所有数据。 | |||
| * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||
| 混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||
| * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||
| 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||
| * 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||
| 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
| sampler, drop_last, ds_ratio 均无效。 | |||
| """ | |||
| def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential', | |||
| collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None, | |||
| sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None, | |||
| def __init__(self, datasets: Dict = None, mode: Union[str, "Sampler"] = 'sequential', | |||
| collate_fn: Union[str, Callable, Dict[str, Callable]] = 'auto', | |||
| sampler: Union[Dict[str, "Sampler"], str, None] = None, | |||
| num_workers: int = 0, batch_size: int = 16, drop_last=False, | |||
| ds_ratio: Union[str, List[float], None, Dict[str, float]] = None, | |||
| pin_memory: bool = True) -> None: | |||
| ds_ratio: Union[None, str, Dict[str, float]] = None, | |||
| pin_memory: bool = False) -> None: | |||
| """ | |||
| :param datasets: dataset的列表 | |||
| :param mode: mode包括四种类型,前三种分别为"sequential", "mix", "polling"分别代表上述三种情况, | |||
| 当mode为Sampler时为用户定制,此时sampler,ds_ratio,batch_size,drop_last失效,此时Sampler应该是一个可迭代 | |||
| 对象,每次迭代返回的是List[int] | |||
| :param collate_fn: 对取得到的数据进行打包的callable函数, | |||
| 当其为callable类型时候,所有数据集采样的数据都会经过这个函数; | |||
| 当其为List[Callable]类型时,datasets也应该为List;会根据每个数据集__getitem__返回的idx判断当前数据对应的Callable函数, | |||
| 其对应关系与datasets位置匹配; | |||
| 当其为Dict[str, Callable]类型时, datasets也是Dict类型且一一对应。 | |||
| :param sampler: sampler是datasets每个数据集内部采样的实例化sampler对象 | |||
| sampler为None时候,datasets包含的每个dataset都会初始化一个sequentialSampler用于采样; | |||
| sampler为List[Sampler],则datasets也为List,且一一对应 | |||
| sampler为Dict[str, Sampler], datasets也是Dict类型且一一对应。 | |||
| :param num_workers: 进程的数量,当num_workers=0时不开启多进程 | |||
| :param batch_size: 批次大小, datasets的所有数据集batch_size一致 | |||
| :param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||
| :param ds_ratio: 当ds_ratio为None,原有数据集不进行扩充 | |||
| 当ds_ratio为'truncate_to_least'时,以datasets的最短数据集为基准,将其他数据集截断到一样长度 | |||
| 当ds_ratio为'pad_to_most'时,以datasets的最长数据集为基准,将最短数据集重采样到最长数据集长度一致为止 | |||
| 当ds_ratio为List[float]时,datasets也为List,ds_ratio的每一个参数都是datasets每个数据集应该采样的倍数, | |||
| 其大于0,可以超过1,将数据集重采样翻倍即可 | |||
| 当ds_ratio为Dict[str, float]时,datasets也为Dict,参数相互对应。 | |||
| :param datasets: 实现了 __getitem__() 和 __len__() 对象的序列或者字典。 | |||
| :param mode: mode 控制 ``MixDataLoader`` 运行模式。 mode 的取值范围为 ``['sequential', 'mix', 'polling', "Sampler"]``: | |||
| * 当 mode 为 ``'sequential'`` 时,``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 按照 datasets 数据集序列或者字典的顺序一个 | |||
| 接一个的 sample 完所有数据。 | |||
| * 当 mode 为 ``'mix'`` 时, ``MixDataLoader`` 将 datasets 的序列或者字典视为一个混合大数据集, 然后根据用户输入的 idx 序列随机sample | |||
| 混合数据集 datasets 的数据组成一个 batch 序列返回。 | |||
| * 当 mode 为 ``'polling'`` 时, ``MixDataLoader`` 按照 datasets 数据集的顺序, 先从第一个数据集采样一个 batch 的数据返回, | |||
| 再从第二数据集采样一个 batch 数据返回, 直至最后一个数据集采样一个 batch 数据返回后再从第一个数据采样第二个 batch 数据返回,直至所有的数据集都被轮询的采样完。 | |||
| * 当 mode 为 ``"Sampler"`` 时, 该 Sampler 是实现 __iter__() 的实例化对象, 其功能是每次 iter 时返回一个 batch 序列, 其类型为 List[int]; | |||
| 且 Sampler 必须将输入的 datasets 视为一个混合大数据集, 其 index 范围为 ``0<idx<len(datasets[0])+...+len(datasets[x])``, 然后参数 | |||
| sampler, drop_last, ds_ratio 均无效。 | |||
| :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数。 其取值可以为 ``['auto', Callable, List[Callable], Dict[str, Callable]]``: | |||
| * collate_fn 为 ``'auto'`` 时, ``MixDataLoader`` datasets 序列或者dict 初始化一个 :class: `~fastNLP.core.collators.Collator` 作为其默认值, | |||
| 需要注意的是只有当 datasets 包含的所以 dataset 的数据都为 ``List`` 或者 ``Dict`` 类型时才能使用。否则只能用户自己定义 collate_fn . | |||
| * collate_fn 为 ``Callable`` 时, 该 collate_fn 会被 datasets 序列或者dict 的所有数据所共享。该 Callable 函数应当接受一个 batch 参数作为输入, | |||
| batch 是一个 List 对象且 List 中的每一条数据都是 dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||
| * collate_fn 为 ``Dict[str, Callable]`` 时, datasets 的 key 必须和 callable_fn 的 key 一致。 ``MixDataLoader`` 会将 ``collate_fn[key]`` | |||
| 用到 ``datasets[key]`` 的数据集上。 ``collate_fn[key]`` 是一个 Callable 对象。 | |||
| :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index ,其取值范围为 | |||
| ``[None, str, Dict[str, "Sampler"]]``: | |||
| * sampler 为 ``None`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
| * sampler 为 ``str`` 时, sampler 选择范围为 ``[rand, seq]``。当 sampler 为 ``rand`` 时,``MixDataLoader`` 默认初始化 ``torch`` 的 ``RandomSampler`` | |||
| 作为默认值, 其功能时随机采样 dataset 的下标并返回。 当 sampler 为 ``seq`` 时, ``MixDataLoader`` 默认初始化 ``torch`` 的 ``SequentialSampler`` 作为默认值。其功能时顺序返回 dataset 的下标。 | |||
| * sampler 为 ``Dict[str, "Sampler"]`` 时, ``Sampler`` 为用户定义的实现了 __len__() 和 __iter__() 的实例化对象。 其每次 iter 必须返回一个 int 下标。 | |||
| Dict 的 str 必须和 datasets 的 key 一致。 也即是 ``Dict[str, Sampler] `` 为 datasets 字典的每个 dataset 初始化勒一个 Sampler。 | |||
| :param num_workers: 当 ``num_workers > 0`` 时, ``MixDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快数据处理速度,但同时 | |||
| 也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||
| :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 且 datasets 上所有 dataset 的 batch_size 一致。 | |||
| :param drop_last: 当 ``drop_last=True`` 时,``MixDataLoader`` 会扔掉 datasets 中 每个 dataset 最后一个长度小于 ``batch_size`` 的 batch 数据; | |||
| 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||
| :param ds_ratio: ``ds_ratio`` 是控制 datasets 怎么组成一个混合大数据集的重要参数, 其取值为 ``[None, 'truncate_to_least', 'pad_to_most', List[float], Dict[str, float]]``: | |||
| * ds_ratio 为 ``None``, datasets 数据集序列或字典不进行数据扩充处理。 | |||
| * ds_ratio 为 ``'truncate_to_least'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最断长度 ``mix_len``, 其他数据集会被切断 | |||
| 到最短长度``mix_len``。这种切断不是物理上切断,``MixDataLoader`` 会根据 sampler 不同来采样数据集到指定的最短长度``mix_len``。 | |||
| * ds_ratio 为 ``'pad_to_most'``, datasets 数据集序列或字典会计算得到 datasets序列中 dataset 最大长度 ``max_len``, 其他其他数据集会扩充 | |||
| 到最大长度``mix_len``。这种扩充不是物理上扩充, ``MixDataLoader`` 会根据 sampler 不同来重采样 dataset 到指定的最大长度``max_len``。 | |||
| * ds_ratio 为 ``Dict[str, float]`` 时, datasets 类型也必须为 ``Dict[str, DataSet]``, 其 key 一一对应。 ds_ratio 的 value 是任意大于 0 的浮点数, | |||
| 代表着 datasets 的 value 数据进行扩充或者缩减的倍数。 | |||
| """ | |||
| # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||
| if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \ | |||
| isinstance(sampler, Dict): | |||
| raise ValueError(f"") | |||
| if isinstance(collate_fn, list): | |||
| if len(collate_fn) != len(datasets): | |||
| raise ValueError("the length of collate_fn != datasets!!") | |||
| if isinstance(sampler, list): | |||
| if len(sampler) != len(datasets): | |||
| raise ValueError("the length of sampler != datasets!!") | |||
| # Dict类型转化为List,以便于_MixCollateFn处理 | |||
| # sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||
| if isinstance(sampler, Dict): | |||
| for key in datasets.keys(): | |||
| if not sampler[key]: | |||
| raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") | |||
| # collate_fn 为 dict,则判断是否与 datasets 的 key 相同 | |||
| if isinstance(collate_fn, Dict): | |||
| if mode == 'mix': | |||
| raise ValueError(f"mode: {mode} do not support collate_fn is Dict, please use callate_fn=Callable or 'auto'") | |||
| for key in datasets.keys(): | |||
| if not collate_fn[key]: | |||
| raise ValueError(f"the key:{key} of datasets is not in collate_fn, where collate_fn is a dict!") | |||
| if isinstance(collate_fn, str) and collate_fn == 'auto': | |||
| date_type = None | |||
| for idx, ds in enumerate(datasets.values()): | |||
| if idx == 0: | |||
| date_type = type(ds[0]) | |||
| if type(ds[0]) != date_type or not (isinstance(ds[0], List) or isinstance(ds[0], Mapping)): | |||
| raise ValueError(f"when you use callate_fn={collate_fn}, all dataset must be list or dict。" | |||
| f"But dataset {idx - 1} data type is {date_type}, dataset {idx} data type is {type(ds[0])}") | |||
| collate_fn = Collator(backend='torch') | |||
| # Dict 类型的 collate_fn 转化为 List,以便于 _MixCollateFn 里面根据 idx 定位 dataset | |||
| if isinstance(collate_fn, Dict): | |||
| collate_fn = [fn for _, fn in collate_fn.items()] | |||
| # 由于datasets可能是FastNLP类型的dataset或者是交杂的, 故需要检测 | |||
| if isinstance(datasets, Dict): | |||
| dataset = [ds for _, ds in datasets.items()] | |||
| else: | |||
| dataset = datasets | |||
| auto_collators = [] | |||
| for per_ds in dataset: | |||
| if isinstance(per_ds, DataSet): | |||
| auto_collators.append(per_ds.get_collator()) | |||
| else: | |||
| # 如果没有对应的collator就设置一个不做任何操作的collator | |||
| auto_collators.append(lambda x: x) | |||
| # List类型的collate_fn只有两种情况,需要对其进行包裹 | |||
| collate_fn = _MixCollateFn(collate_fn, auto_collators) | |||
| dataset = [ds for _, ds in datasets.items()] | |||
| # 对 collate_fn 进行包裹, 统一处理 collate_fn 不同情况下使用的问题 | |||
| collate_fn = _MixCollateFn(collate_fn) | |||
| if mode == 'sequential': | |||
| batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, | |||
| drop_last=drop_last, ds_ratio=ds_ratio) | |||
| @@ -21,9 +21,9 @@ class MixSampler: | |||
| mix_sampler的基类 | |||
| """ | |||
| def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||
| sampler: Union[List["Sampler"], Dict[str, "Sampler"], None, str] = None, | |||
| ds_ratio: Union[str, List[float], Dict[str, float]] = None, | |||
| def __init__(self, dataset: Dict, batch_size: int = None, | |||
| sampler: Union[Dict[str, "Sampler"], None, str] = None, | |||
| ds_ratio: Union[str, Dict[str, float]] = None, | |||
| drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||
| """ | |||
| @@ -32,9 +32,12 @@ class MixSampler: | |||
| :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | |||
| :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | |||
| """ | |||
| # 如果dataset为Dict,则其他参数如collate_fn必须为Dict或者Callable, | |||
| if isinstance(dataset, Dict) and isinstance(sampler, List): | |||
| raise ValueError(f"{sampler} must be dict") | |||
| # sampler 为 dict,则判断是否与 datasets 的 key 相同 | |||
| if isinstance(sampler, Dict): | |||
| for key in dataset.keys(): | |||
| if not sampler[key]: | |||
| raise ValueError(f"the key:{key} of datasets is not in sampler, where sampler is a dict!") | |||
| if batch_size <= 0: | |||
| raise ValueError("batch_size should be a positive integer value, " | |||
| "but got batch_size={}".format(batch_size)) | |||
| @@ -46,15 +49,7 @@ class MixSampler: | |||
| raise ValueError("if rank>=0 and word_size>=0, sampler must be str") | |||
| if sampler is None and (word_size < 0 or rank < 0): | |||
| if isinstance(dataset, List): | |||
| self.sampler = [SequentialSampler(ds) for ds in dataset] | |||
| elif isinstance(dataset, Dict): | |||
| self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||
| elif isinstance(sampler, List): | |||
| if len(sampler) != len(dataset): | |||
| raise ValueError("the length of sampler != the length of sampler") | |||
| self.sampler = sampler | |||
| self.sampler = {name: SequentialSampler(ds) for name, ds in dataset.items()} | |||
| elif isinstance(sampler, Dict): | |||
| self.sampler = sampler | |||
| @@ -68,26 +63,7 @@ class MixSampler: | |||
| # 计算扩展后的大数据集长度total_len和扩展后的单个数据集长度sampler_len | |||
| sampler_lens, total_lens, sampler_index = [], 0, [] | |||
| if isinstance(self.sampler, List): | |||
| if ds_ratio is None: | |||
| sampler_lens = [len(spl) for spl in self.sampler] | |||
| elif ds_ratio == 'pad_to_most': | |||
| sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif ds_ratio == 'truncate_to_least': | |||
| sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif isinstance(ds_ratio, List): | |||
| if not all(item >= 0 for item in ds_ratio): | |||
| raise ValueError("batch_size should be a positive integer value, " | |||
| "but got ds_ratio={}".format(ds_ratio)) | |||
| sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, ds_ratio)] | |||
| else: | |||
| raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
| total_lens = sum(sampler_lens) | |||
| elif isinstance(self.sampler, Dict): | |||
| if isinstance(self.sampler, Dict): | |||
| if ds_ratio is None: | |||
| sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
| @@ -100,7 +76,7 @@ class MixSampler: | |||
| sampler_lens = [min(len(spl) for _, spl in self.sampler.items())] * sampler_len | |||
| elif isinstance(ds_ratio, Dict): | |||
| if not all(item >= 0 for item in ds_ratio): | |||
| if not all([item >= 0 for item in ds_ratio.values()]): | |||
| raise ValueError("batch_size should be a positive integer value, " | |||
| "but got ds_ratio={}".format(ds_ratio)) | |||
| sampler_lens = [int(len(spl) * ds_ratio[name]) for name, spl in self.sampler.items()] | |||
| @@ -108,7 +84,7 @@ class MixSampler: | |||
| raise ValueError(f"{ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
| total_lens = sum(sampler_lens) | |||
| # sampler为str时候,初始化下移到iter方法中 | |||
| # sampler 为 str 时候,初始化下移到 iter 方法中 | |||
| if len(sampler_lens) > 0: | |||
| sampler_index = [sampler_lens[0]] | |||
| for idx in sampler_lens[1:]: | |||
| @@ -160,75 +136,37 @@ class DopedSampler(MixSampler): | |||
| """ | |||
| 定制给MixDataLoader的BatchSampler,其功能是将传入的datasets的list列表混合采样组成一个个batch返回。 | |||
| """ | |||
| def __init__(self, dataset: Union[List, Dict], batch_size: int = None, | |||
| sampler: Union[List["Sampler"], Dict[str, "Sampler"], str] = None, | |||
| ds_ratio: Union[str, None, List[float], Dict[str, float]] = None, | |||
| def __init__(self, dataset: Dict, batch_size: int = None, | |||
| sampler: Union[Dict[str, "Sampler"], str] = None, | |||
| ds_ratio: Union[str, None, Dict[str, float]] = None, | |||
| drop_last: bool = False, rank: int = -1, word_size: int = -1) -> None: | |||
| super(DopedSampler, self).__init__(dataset=dataset, batch_size=batch_size, | |||
| sampler=sampler, ds_ratio=ds_ratio, | |||
| drop_last=drop_last, rank=rank, word_size=word_size) | |||
| def __iter__(self) -> List[int]: | |||
| # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||
| # sampler 为 str, 此时为单机多卡或者单机,可以实现 rand 随机化 | |||
| if isinstance(self.sampler, str): | |||
| if self.sampler == 'seq': | |||
| if isinstance(self.datasets, List): | |||
| self.sampler = [] | |||
| for per_ds in self.datasets: | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||
| else: | |||
| self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||
| elif isinstance(self.datasets, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
| elif self.sampler == 'rand': | |||
| if isinstance(self.datasets, List): | |||
| self.sampler = [] | |||
| for per_ds in self.datasets: | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||
| else: | |||
| self.sampler.append(InnerSampler(indices)) | |||
| elif isinstance(self.datasets, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(indices) | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(indices) | |||
| # 根据给定的ds_ratio计算真正需要处理数据集 | |||
| if isinstance(self.sampler, List): | |||
| if self.ds_ratio is None: | |||
| sampler_lens = [len(spl) for spl in self.sampler] | |||
| elif self.ds_ratio == 'pad_to_most': | |||
| sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif self.ds_ratio == 'truncate_to_least': | |||
| sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif isinstance(self.ds_ratio, List): | |||
| if not all(item >= 0 for item in self.ds_ratio): | |||
| raise ValueError("batch_size should be a positive integer value, " | |||
| "but got ds_ratio={}".format(self.ds_ratio)) | |||
| sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||
| else: | |||
| raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
| total_lens = sum(sampler_lens) | |||
| elif isinstance(self.sampler, Dict): | |||
| if isinstance(self.sampler, Dict): | |||
| if self.ds_ratio is None: | |||
| sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
| @@ -257,11 +195,11 @@ class DopedSampler(MixSampler): | |||
| sampler_index.append(temp + idx) | |||
| self.num_samplers = sampler_index | |||
| self.len_samplers = total_lens | |||
| # 每个batch的数据, 总的数据量total_index, 每个数据集的samplers | |||
| # 每个 batch 的数据, 总的数据量 total_index , 每个数据集的 samplers | |||
| batch_idx, samplers = [], [] | |||
| # 如果单机则用所有数据,否则采用多卡 | |||
| if self.rank < 0 or self.word_size < 0: | |||
| # 根据sampler长度判断是否使用unsigned int 或者unsigned long | |||
| # 根据 sampler 长度判断是否使用 unsigned int 或者 unsigned long | |||
| if self.len_samplers > 42e8: | |||
| total_index = array.array('L', list(range(self.len_samplers))) | |||
| else: | |||
| @@ -274,15 +212,17 @@ class DopedSampler(MixSampler): | |||
| else: | |||
| total_index = array.array('I', list(range(self.len_samplers))[self.rank::self.word_size]) | |||
| start_idx = 0 | |||
| # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
| for idx, (name, spl) in enumerate(self.sampler.items()): | |||
| end_idx = len(spl) | |||
| samplers.append((iter(spl), name, start_idx)) | |||
| start_idx += end_idx | |||
| # 根据sampler的类型取出每个数据集的sampler | |||
| if isinstance(self.sampler, List): | |||
| sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||
| samplers = [(iter(spl), idx, base_index) | |||
| for idx, (spl, base_index) in enumerate(zip(self.sampler, sampler_base_index))] | |||
| else: | |||
| sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
| samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
| for idx, (name, spl) in enumerate(self.sampler.items())] | |||
| # sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
| # samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
| # for idx, (name, spl) in enumerate(self.sampler.items())] | |||
| # 生成随机数 | |||
| np.random.seed(self.epoch) | |||
| np.random.shuffle(total_index) | |||
| @@ -295,7 +235,7 @@ class DopedSampler(MixSampler): | |||
| # 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | |||
| spl = iter(self.sampler[name]) | |||
| batch_idx.append(next(spl) + base_index) | |||
| samplers[name] = (spl, name, base_index) | |||
| samplers[ds_index] = (spl, name, base_index) | |||
| if len(batch_idx) == self.batch_size: | |||
| yield batch_idx | |||
| batch_idx = [] | |||
| @@ -343,63 +283,26 @@ class MixSequentialSampler(MixSampler): | |||
| # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||
| if isinstance(self.sampler, str): | |||
| if self.sampler == 'seq': | |||
| if isinstance(self.datasets, List): | |||
| self.sampler = [] | |||
| for per_ds in self.datasets: | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||
| else: | |||
| self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||
| elif isinstance(self.datasets, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
| elif self.sampler == 'rand': | |||
| if isinstance(self.datasets, List): | |||
| self.sampler = [] | |||
| for per_ds in self.datasets: | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||
| else: | |||
| self.sampler.append(InnerSampler(indices)) | |||
| elif isinstance(self.datasets, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(indices) | |||
| # 根据给定的ds_ratio计算真正需要处理数据集 | |||
| if isinstance(self.sampler, List): | |||
| if self.ds_ratio is None: | |||
| sampler_lens = [len(spl) for spl in self.sampler] | |||
| elif self.ds_ratio == 'pad_to_most': | |||
| sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif self.ds_ratio == 'truncate_to_least': | |||
| sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif isinstance(self.ds_ratio, List): | |||
| if not all(item >= 0 for item in self.ds_ratio): | |||
| raise ValueError("batch_size should be a positive integer value, " | |||
| "but got ds_ratio={}".format(self.ds_ratio)) | |||
| sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||
| else: | |||
| raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
| total_lens = sum(sampler_lens) | |||
| elif isinstance(self.sampler, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(indices) | |||
| # 根据给定的 ds_ratio 算真正需要处理数据集 | |||
| if isinstance(self.sampler, Dict): | |||
| if self.ds_ratio is None: | |||
| sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
| @@ -430,21 +333,20 @@ class MixSequentialSampler(MixSampler): | |||
| self.len_samplers = total_lens | |||
| batch_idx, total_index, samplers = [], list(range(self.len_samplers)), [] | |||
| if isinstance(self.sampler, List): | |||
| if self.word_size > 0 and self.rank >= 0: | |||
| sampler_base_index = [0] + [len(spl) * self.word_size for spl in self.sampler][:-1] | |||
| else: | |||
| sampler_base_index = [0] + [len(spl) for spl in self.sampler][:-1] | |||
| samplers = [(iter(spl), idx, base_index) for idx, (spl, base_index) in | |||
| enumerate(zip(self.sampler, sampler_base_index))] | |||
| else: | |||
| if self.word_size > 0 and self.rank >= 0: | |||
| sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||
| else: | |||
| sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
| samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
| for idx, (name, spl) in enumerate(self.sampler.items())] | |||
| start_idx = 0 | |||
| # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
| for idx, (name, spl) in enumerate(self.sampler.items()): | |||
| end_idx = len(spl) | |||
| samplers.append((iter(spl), name, start_idx)) | |||
| start_idx += end_idx | |||
| # if self.word_size > 0 and self.rank >= 0: | |||
| # sampler_base_index = [0] + [len(spl) * self.word_size for _, spl in self.sampler.items()][:-1] | |||
| # else: | |||
| # sampler_base_index = [0] + [len(spl) for _, spl in self.sampler.items()][:-1] | |||
| # | |||
| # samplers = [(iter(spl), name, sampler_base_index[idx]) | |||
| # for idx, (name, spl) in enumerate(self.sampler.items())] | |||
| for idx in total_index: | |||
| ds_index = np.searchsorted(self.num_samplers, idx, side='right') | |||
| @@ -455,7 +357,7 @@ class MixSequentialSampler(MixSampler): | |||
| # 重新初始化一个新的sampler,因为不可能为空,故一定不会出现stopIteration | |||
| spl = iter(self.sampler[name]) | |||
| batch_idx.append(next(spl) + base_index) | |||
| samplers[name] = (spl, name, base_index) | |||
| samplers[ds_index] = (spl, name, base_index) | |||
| if len(batch_idx) == self.batch_size: | |||
| yield batch_idx | |||
| batch_idx = [] | |||
| @@ -506,63 +408,26 @@ class PollingSampler(MixSampler): | |||
| # sampler为str, 此时为单机多卡或者单机,可以实现rand随机化 | |||
| if isinstance(self.sampler, str): | |||
| if self.sampler == 'seq': | |||
| if isinstance(self.datasets, List): | |||
| self.sampler = [] | |||
| for per_ds in self.datasets: | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler.append(InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size])) | |||
| else: | |||
| self.sampler.append(InnerSampler(list(range(len(per_ds))))) | |||
| elif isinstance(self.datasets, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(list(range(len(per_ds)))) | |||
| elif self.sampler == 'rand': | |||
| if isinstance(self.datasets, List): | |||
| self.sampler = [] | |||
| for per_ds in self.datasets: | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler.append(InnerSampler(indices[self.rank::self.word_size])) | |||
| else: | |||
| self.sampler.append(InnerSampler(indices)) | |||
| elif isinstance(self.datasets, Dict): | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(indices) | |||
| # 根据给定的ds_ratio计算真正需要处理数据集 | |||
| if isinstance(self.sampler, List): | |||
| if self.ds_ratio is None: | |||
| sampler_lens = [len(spl) for spl in self.sampler] | |||
| elif self.ds_ratio == 'pad_to_most': | |||
| sampler_lens = [max(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| self.sampler = {} | |||
| for name, per_ds in self.datasets.items(): | |||
| g = torch.Generator() | |||
| g.manual_seed(self.epoch) | |||
| indices = torch.randperm(len(per_ds), generator=g).tolist() | |||
| if self.word_size >= 0 and self.rank >= 0: | |||
| self.sampler[name] = InnerSampler(indices[self.rank::self.word_size]) | |||
| else: | |||
| self.sampler[name] = InnerSampler(indices) | |||
| elif self.ds_ratio == 'truncate_to_least': | |||
| sampler_lens = [min(len(spl) for spl in self.sampler)] * len(self.sampler) | |||
| elif isinstance(self.ds_ratio, List): | |||
| if not all(item >= 0 for item in self.ds_ratio): | |||
| raise ValueError("batch_size should be a positive integer value, " | |||
| "but got ds_ratio={}".format(self.ds_ratio)) | |||
| sampler_lens = [int(len(spl) * ratio) for spl, ratio in zip(self.sampler, self.ds_ratio)] | |||
| else: | |||
| raise ValueError(f"{self.ds_ratio} must be pad_to_least or truncate_to_least or None or List") | |||
| total_lens = sum(sampler_lens) | |||
| elif isinstance(self.sampler, Dict): | |||
| # 根据给定的ds_ratio计算真正需要处理数据集 | |||
| if isinstance(self.sampler, Dict): | |||
| if self.ds_ratio is None: | |||
| sampler_lens = [len(spl) for _, spl in self.sampler.items()] | |||
| @@ -592,17 +457,15 @@ class PollingSampler(MixSampler): | |||
| self.num_samplers = sampler_index | |||
| self.len_samplers = total_lens | |||
| start_idx, samplers = 0, [] | |||
| if isinstance(self.sampler, List): | |||
| # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
| for sampler_idx, (end_idx, spl) in enumerate(zip(self.num_samplers, self.sampler)): | |||
| samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, sampler_idx)) | |||
| start_idx = end_idx | |||
| else: | |||
| for idx, (name, spl) in enumerate(self.sampler.items()): | |||
| end_idx = self.num_samplers[idx] | |||
| samplers.append((iter(range(start_idx, end_idx)), iter(spl), start_idx, name)) | |||
| start_idx = end_idx | |||
| start_idx, samplers, true_start_idx, true_end_idx = 0, [], 0, 0 | |||
| # (特定数据集需要长度,特定数据集sampler, 特定数据集的基址, 特定sampler的下标) | |||
| for idx, (name, spl) in enumerate(self.sampler.items()): | |||
| end_idx = len(spl) | |||
| true_end_idx = self.num_samplers[idx] | |||
| samplers.append((iter(range(true_start_idx, true_end_idx)), iter(spl), start_idx, name)) | |||
| start_idx += end_idx | |||
| true_start_idx = true_end_idx | |||
| while True: | |||
| # 退出循环 | |||
| @@ -0,0 +1,495 @@ | |||
| import pytest | |||
| from typing import Mapping | |||
| from fastNLP.core.dataloaders import MixDataLoader | |||
| from fastNLP import DataSet | |||
| from fastNLP.core.collators import Collator | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from torch.utils.data import default_collate, SequentialSampler, RandomSampler | |||
| d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
| d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10}) | |||
| d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) | |||
| def test_pad_val(tensor, val=0): | |||
| if isinstance(tensor, torch.Tensor): | |||
| tensor = tensor.tolist() | |||
| for item in tensor: | |||
| if item[-1] > 0: | |||
| continue | |||
| elif item[-1] != val: | |||
| return False | |||
| return True | |||
| class TestMixDataLoader: | |||
| def test_sequential_init(self): | |||
| datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
| # drop_last = True, collate_fn = 'auto | |||
| dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True) | |||
| for idx, batch in enumerate(dl): | |||
| if idx == 0: | |||
| # d1 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| # d2 | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| # collate_fn = Callable | |||
| def collate_batch(batch): | |||
| new_batch = {'x': [], 'y': []} | |||
| for ins in batch: | |||
| new_batch['x'].append(ins['x']) | |||
| new_batch['y'].append(ins['y']) | |||
| return new_batch | |||
| dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True) | |||
| for idx, batch in enumerate(dl1): | |||
| if idx == 0: | |||
| # d1 | |||
| assert [1, 2] in batch['x'] | |||
| if idx == 1: | |||
| # d2 | |||
| assert [101, 201] in batch['x'] | |||
| if idx > 1: | |||
| # d3 | |||
| assert [1000, 2000] in batch['x'] | |||
| assert 'x' in batch and 'y' in batch | |||
| collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||
| 'd2': Collator(backend='auto').set_pad("x", -2), | |||
| 'd3': Collator(backend='auto').set_pad("x", -3)} | |||
| dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) | |||
| for idx, batch in enumerate(dl2): | |||
| if idx == 0: | |||
| assert test_pad_val(batch['x'], val=-1) | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| assert test_pad_val(batch['x'], val=-2) | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx > 1: | |||
| assert test_pad_val(batch['x'], val=-3) | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| # sampler 为 str | |||
| dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True) | |||
| dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True) | |||
| for idx, batch in enumerate(dl3): | |||
| if idx == 0: | |||
| # d1 | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx == 2: | |||
| # d3 | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| for idx, batch in enumerate(dl4): | |||
| if idx == 0: | |||
| # d1 | |||
| assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx == 2: | |||
| # d3 | |||
| assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| # sampler 为 Dict | |||
| samplers = {'d1': SequentialSampler(d1), | |||
| 'd2': SequentialSampler(d2), | |||
| 'd3': RandomSampler(d3)} | |||
| dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True) | |||
| for idx, batch in enumerate(dl5): | |||
| if idx == 0: | |||
| # d1 | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx > 1: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 'truncate_to_least' | |||
| dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) | |||
| for idx, batch in enumerate(dl6): | |||
| if idx == 0: | |||
| # d1 | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if idx == 1: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if idx == 2: | |||
| # d3 | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx > 2: | |||
| raise ValueError(f"ds_ratio: 'truncate_to_least' error") | |||
| # ds_ratio 为 'pad_to_most' | |||
| dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True) | |||
| for idx, batch in enumerate(dl7): | |||
| if idx < 18: | |||
| # d1 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if 18 <= idx < 36: | |||
| # d2 | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if 36 <= idx < 54: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 54: | |||
| raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
| # ds_ratio 为 Dict[str, float] | |||
| ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
| dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) | |||
| for idx, batch in enumerate(dl8): | |||
| if idx < 1: | |||
| # d1 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| if 1 <= idx < 4: | |||
| # d2 | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if 4 <= idx < 41: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 41: | |||
| raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
| ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
| dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True) | |||
| for idx, batch in enumerate(dl9): | |||
| if idx < 1: | |||
| # d2 | |||
| assert batch['x'].shape == torch.Size([16, 3]) | |||
| if 1 <= idx < 19: | |||
| # d3 | |||
| assert batch['x'].shape == torch.Size([16, 4]) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 19: | |||
| raise ValueError(f"ds_ratio: 'pad_to_most' error") | |||
| def test_mix(self): | |||
| datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
| dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) | |||
| for idx, batch in enumerate(dl): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| # collate_fn = Callable | |||
| def collate_batch(batch): | |||
| new_batch = {'x': [], 'y': []} | |||
| for ins in batch: | |||
| new_batch['x'].append(ins['x']) | |||
| new_batch['y'].append(ins['y']) | |||
| return new_batch | |||
| dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) | |||
| for idx, batch in enumerate(dl1): | |||
| assert isinstance(batch['x'], list) | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||
| 'd2': Collator(backend='auto').set_pad("x", -2), | |||
| 'd3': Collator(backend='auto').set_pad("x", -3)} | |||
| with pytest.raises(ValueError): | |||
| MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns) | |||
| # sampler 为 str | |||
| dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) | |||
| for idx, batch in enumerate(dl3): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) | |||
| for idx, batch in enumerate(dl4): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| # sampler 为 Dict | |||
| samplers = {'d1': SequentialSampler(d1), | |||
| 'd2': SequentialSampler(d2), | |||
| 'd3': RandomSampler(d3)} | |||
| dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) | |||
| for idx, batch in enumerate(dl5): | |||
| assert test_pad_val(batch['x'], val=0) | |||
| if idx >= 22: | |||
| raise ValueError(f"out of range") | |||
| # ds_ratio 为 'truncate_to_least' | |||
| dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least') | |||
| d1_len, d2_len, d3_len = 0, 0, 0 | |||
| for idx, batch in enumerate(dl6): | |||
| for item in batch['y'].tolist(): | |||
| if item in [1, 0, 1]: | |||
| d1_len += 1 | |||
| elif item in [20, 10, 10]: | |||
| d2_len += 1 | |||
| elif item in [100, 100, 200]: | |||
| d3_len += 1 | |||
| if idx >= 6: | |||
| raise ValueError(f"ds_ratio 为 'truncate_to_least'出错了") | |||
| assert d1_len == d2_len == d3_len == 30 | |||
| # ds_ratio 为 'pad_to_most' | |||
| dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most') | |||
| d1_len, d2_len, d3_len = 0, 0, 0 | |||
| for idx, batch in enumerate(dl7): | |||
| for item in batch['y'].tolist(): | |||
| if item in [1, 0, 1]: | |||
| d1_len += 1 | |||
| elif item in [20, 10, 10]: | |||
| d2_len += 1 | |||
| elif item in [100, 100, 200]: | |||
| d3_len += 1 | |||
| if idx >= 57: | |||
| raise ValueError(f"ds_ratio 为 'pad_to_most'出错了") | |||
| assert d1_len == d2_len == d3_len == 300 | |||
| # ds_ratio 为 Dict[str, float] | |||
| ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
| dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) | |||
| d1_len, d2_len, d3_len = 0, 0, 0 | |||
| for idx, batch in enumerate(dl8): | |||
| for item in batch['y'].tolist(): | |||
| if item in [1, 0, 1]: | |||
| d1_len += 1 | |||
| elif item in [20, 10, 10]: | |||
| d2_len += 1 | |||
| elif item in [100, 100, 200]: | |||
| d3_len += 1 | |||
| if idx >= 44: | |||
| raise ValueError(f"ds_ratio 为 'Dict'出错了") | |||
| assert d1_len == 30 | |||
| assert d2_len == 60 | |||
| assert d3_len == 600 | |||
| ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
| dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio) | |||
| d1_len, d2_len, d3_len = 0, 0, 0 | |||
| for idx, batch in enumerate(dl9): | |||
| for item in batch['y'].tolist(): | |||
| if item in [1, 0, 1]: | |||
| d1_len += 1 | |||
| elif item in [20, 10, 10]: | |||
| d2_len += 1 | |||
| elif item in [100, 100, 200]: | |||
| d3_len += 1 | |||
| if idx >= 21: | |||
| raise ValueError(f"ds_ratio 为 'Dict'出错了") | |||
| def test_polling(self): | |||
| datasets = {'d1': d1, 'd2': d2, 'd3': d3} | |||
| dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18) | |||
| for idx, batch in enumerate(dl): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| # collate_fn = Callable | |||
| def collate_batch(batch): | |||
| new_batch = {'x': [], 'y': []} | |||
| for ins in batch: | |||
| new_batch['x'].append(ins['x']) | |||
| new_batch['y'].append(ins['y']) | |||
| return new_batch | |||
| dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18) | |||
| for idx, batch in enumerate(dl1): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]] | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]] | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]] | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1), | |||
| 'd2': Collator(backend='auto').set_pad("x", -2), | |||
| 'd3': Collator(backend='auto').set_pad("x", -3)} | |||
| dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) | |||
| for idx, batch in enumerate(dl1): | |||
| if idx == 0 or idx == 3: | |||
| assert test_pad_val(batch['x'], val=-1) | |||
| assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert test_pad_val(batch['x'], val=-2) | |||
| assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert test_pad_val(batch['x'], val=-3) | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| # sampler 为 str | |||
| dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18) | |||
| dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18) | |||
| for idx, batch in enumerate(dl2): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| for idx, batch in enumerate(dl3): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| # sampler 为 Dict | |||
| samplers = {'d1': SequentialSampler(d1), | |||
| 'd2': SequentialSampler(d2), | |||
| 'd3': RandomSampler(d3)} | |||
| dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18) | |||
| for idx, batch in enumerate(dl4): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or 4 < idx <= 20: | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 20: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 'truncate_to_least' | |||
| dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) | |||
| for idx, batch in enumerate(dl5): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or idx == 5: | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 5: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 'pad_to_most' | |||
| dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) | |||
| for idx, batch in enumerate(dl6): | |||
| if idx % 3 == 0: | |||
| # d1 | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx % 3 == 1: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| if idx % 3 == 2: | |||
| # d3 | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx >= 51: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| # ds_ratio 为 Dict[str, float] | |||
| ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} | |||
| dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||
| for idx, batch in enumerate(dl7): | |||
| if idx == 0 or idx == 3: | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1 or idx == 4 or idx == 6 or idx == 8: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx == 2 or idx == 5 or idx == 7 or idx > 8: | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 39: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||
| ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} | |||
| dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) | |||
| for idx, batch in enumerate(dl8): | |||
| if idx == 0: | |||
| assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]] | |||
| assert batch['x'].shape[1] == 4 | |||
| elif idx == 1: | |||
| # d2 | |||
| assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]] | |||
| assert batch['x'].shape[1] == 3 | |||
| elif idx > 1: | |||
| assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] | |||
| assert batch['x'].shape[1] == 4 | |||
| if idx > 18: | |||
| raise ValueError(f"out of range") | |||
| test_pad_val(batch['x'], val=0) | |||