From b76c6816bec82c78d7b8e058b00151d02c178bf9 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Wed, 18 May 2022 20:52:03 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E5=A2=9E=E5=8A=A0fdl=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/dataloaders/jittor_dataloader/fdl.py | 111 ++++++++++++- .../core/dataloaders/paddle_dataloader/fdl.py | 89 ++++++++-- .../core/dataloaders/torch_dataloader/fdl.py | 152 +++++++++++------- 3 files changed, 282 insertions(+), 70 deletions(-) diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 6e0ef8f5..0705866d 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -3,7 +3,7 @@ __all__ = [ 'prepare_jittor_dataloader' ] -from typing import Callable, Optional, List, Union +from typing import Callable, Optional, List, Union, Dict, Sequence from copy import deepcopy import numpy as np @@ -185,5 +185,110 @@ class JittorDataLoader: return self.cur_batch_indices -def prepare_jittor_dataloader(): - ... +def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: bool = True, + drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, + stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, + collate_fn: Union[None, str, Callable] = "auto", + non_train_batch_size: int = 16) \ + -> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]: + """ + prepare_jittor_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle, + Sequence[Dataset], Dict[name, Dataset]]``,具体如下: + + * 当ds_or_db为Dataset时,prepare_jittor_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个 + JittorDataLoader并返回。 + * 当ds_or_db为FastNLP的DataBundle时,prepare_jittor_dataloader会遍历所有的dataset并根据其name实例化不同的JittorDataLoader, + 当name中包含'train'字符串时,prepare_jittor_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 + 的dataset均使用non_train_batch_size作为batch_size来实例化JittorDataLoader。最终根据name:JittorDataLoader组成一个Dict[name, JittorDataLoader] + 的数据返回。 + * 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_jittor_dataloader会遍历所有的dataset并根据其name实例化不同的JittorDataLoader, + 当name中包含'train'字符串时,prepare_jittor_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 + 的dataset均使用non_train_batch_size作为batch_size来实例化JittorDataLoader。最终根据name:JittorDataLoader组成一个Dict[name, JittorDataLoader] + 的数据返回。 + * 当ds_or_db为Sequence[Dataset]数据类型时, prepare_jittor_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为 + 其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化JittorDataLoader。最终 + 将所有JittorDataLoader组成Sequence[JittorDataLoader]返回。 + + :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, + Sequence[Dataset], Dict[name, Dataset]]``. + :param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 + :param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 + :param shuffle: 是否打乱数据集 + :param drop_last: 是否去掉最后一个不符合``batch_size``的数据 + :param num_workers: 进程的数量,当``num_workers=0``时不开启多进程 + :param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 + :param stop_grad: + :param keep_numpy_array: 返回的数据是``np.array`类`型而不是``jittor.array``类型,默认为``False`` + :param endless: 是否让``JittorDataLoader``无限返回数据,也就是将dataset循环使用使得返回数据是没有限制的。默认为``False``. + :param collate_fn: 用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. + + * ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; + 第二点注意的是此时``JittorDataLoader``会调用默认的`callate_batch`函数对sampler到的数据进行简单打包,组成一个batch返回。` + * ``callate_fn="auto"``时,``JittorDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, + 并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, + 可以调用``set_ignore``方法忽略某个field。 + * ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``JittorDataLoader``会调用传进来的callable函数对 + 数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 + + :return: 返回数据类型为Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader其中之一,根据输入ds_or_db变化而变化。 + """ + from fastNLP.io.data_bundle import DataBundle + if isinstance(ds_or_db, Dataset): + dl = JittorDataLoader(ds_or_db, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, + stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, + collate_fn=collate_fn) + return dl + elif isinstance(ds_or_db, DataBundle): + dl_bundle = {} + for name, ds in ds_or_db.iter_datasets(): + if 'train' in name: + dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, num_workers=num_workers, + buffer_size=buffer_size, + stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, + endless=endless, + collate_fn=collate_fn) + else: + dl_bundle[name] = JittorDataLoader(ds_or_db, + batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + shuffle=shuffle, + drop_last=drop_last, num_workers=num_workers, + buffer_size=buffer_size, + stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, + endless=endless, + collate_fn=collate_fn) + return dl_bundle + elif isinstance(ds_or_db, Sequence): + ds_seq = [] + for idx, ds in enumerate(ds_or_db): + if idx > 0: + train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size + dl = JittorDataLoader(ds, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, + stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, + collate_fn=collate_fn) + ds_seq.append(dl) + return ds_seq + + elif isinstance(ds_or_db, Dict): + ds_dict = {} + for name, ds in ds_or_db.items(): + if 'train' in name: + dl = JittorDataLoader(ds, batch_size=train_batch_size, shuffle=shuffle, + drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, + stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, + collate_fn=collate_fn) + else: + dl = JittorDataLoader(ds_or_db, + batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + shuffle=shuffle, + drop_last=drop_last, num_workers=num_workers, + buffer_size=buffer_size, + stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, + endless=endless, + collate_fn=collate_fn) + ds_dict[name] = dl + return ds_dict + else: + raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 36f6588b..db7bc47e 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -195,8 +195,8 @@ class PaddleDataLoader(DataLoader): field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 无意义。 :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'paddle', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, paddle.Tensor, paddle.Var 类型。若 pad_val 为 None ,该值无意义 。 :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch 形式,输出将被直接作为结果输出。 @@ -253,17 +253,81 @@ class PaddleDataLoader(DataLoader): def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, return_list: bool = True, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, - batch_size: int = 1, shuffle: bool = False, + train_batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, worker_init_fn: Callable = None, persistent_workers=False, non_train_batch_size: int = 16) \ -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: + """ + prepare_paddle_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle, + Sequence[Dataset], Dict[name, Dataset]]``,具体如下: + + * 当ds_or_db为Dataset时,prepare_paddle_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个 + paddleDataLoader并返回。 + * 当ds_or_db为FastNLP的DataBundle时,prepare_paddle_dataloader会遍历所有的dataset并根据其name实例化不同的paddleDataLoader, + 当name中包含'train'字符串时,prepare_paddle_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 + 的dataset均使用non_train_batch_size作为batch_size来实例化paddleDataLoader。最终根据name:paddleDataLoader组成一个Dict[name, paddleDataLoader] + 的数据返回。 + * 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_paddle_dataloader会遍历所有的dataset并根据其name实例化不同的paddleDataLoader, + 当name中包含'train'字符串时,prepare_paddle_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 + 的dataset均使用non_train_batch_size作为batch_size来实例化paddleDataLoader。最终根据name:paddleDataLoader组成一个Dict[name, paddleDataLoader] + 的数据返回。 + * 当ds_or_db为Sequence[Dataset]数据类型时, prepare_paddle_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为 + 其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化paddleDataLoader。最终 + 将所有paddleDataLoader组成Sequence[paddleDataLoader]返回。 + + :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, + Sequence[Dataset], Dict[name, Dataset]]``. + :param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 + :param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 + :param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. + The Tensors should be created by :code:`paddle.static.data()`. + :attr:`feed_list` must be set if :attr:`return_list` is + False. Default None. + :param places: (list(Place)|tuple(Place)|list(str)|optional): a list of Place, + to put data onto, :attr:`places` can be None, if + :attr:`places` is None, default place(CPUPlace or CUDAPlace(0)) + will be used. Default None. If ``places`` is list of string, + the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, + where ``x`` is the index of the GPUs. + :param return_list: whether the return value on each device is + presented as a list. If :attr:`return_list=False`, the return + value on each device would be a dict of str -> Tensor, where + the key of the dict is the name of each fed Tensors. If + :attr:`return_list=True`, the return value on each device would + be a list(Tensor). :attr:`return_list` can only be True + in dynamic graph mode. Default True. + :param batch_sampler: 实现了``__iter__``和``__len__``方法的实例化对象,它的功能是根据dataset生成数据indices并组成一个batch数据。 + :param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。 + :param drop_last: 当``drop_last=True``时,``PaddleDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据; + 若``drop_last=False``, 则什么也不做。 + :param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. + + * ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; + 第二点注意的是此时``PaddleDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。` + * ``callate_fn="auto"``时,``PaddleDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, + 并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, + 可以调用``set_ignore``方法忽略某个field。 + * ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``PaddleDataLoader``会调用传进来的callable函数对 + 数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 + + :param num_workers: 开启多进程的数量,当``num_workers=0``时不开启多进程 + :param use_buffer_reader: 是否开启buffer_reader。如果``use_buffer_reader=True``,那么``PaddleDataLoader``将会异步的预取下一个batch的 + 数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是``True``。如果``use_buffer_reader=False``,那么什么也不错 + :param use_shared_memory: 是否使用共享内存。当``use_shared_memory=True``时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 + 共享空间足够大时。(例如Linux上的/dev/shm/空间足够大)共享内存仅在多进程模式(num_workers>0)下生效。 + :param timeout: 从子进程的输出队列获取数据的超时值 + :param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。 + :param persistent_workers: + + :return: + """ from fastNLP.io.data_bundle import DataBundle if isinstance(ds_or_db, Dataset): dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) @@ -274,7 +338,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, if 'train' in name: dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=batch_size, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, @@ -284,7 +348,8 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, else: dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=non_train_batch_size, + batch_sampler=batch_sampler, + batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, @@ -294,9 +359,11 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, return dl_bundle elif isinstance(ds_or_db, Sequence): ds_seq = [] - for ds in ds_or_db: + for idx, ds in enumerate(ds_or_db): + if idx > 0: + train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) @@ -308,14 +375,16 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, for name, ds in ds_or_db.items(): if 'train' in name: dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) else: dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, + batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index f5e4af97..7c54bed7 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -43,11 +43,11 @@ class _FDataSet: class TorchDataLoader(DataLoader): """ - 提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过 - 提供的方法调节设置collate_fn的若干参数。 + 提供给``torch``框架使用的``DataLoader``函数,``TorchDataLoader``提供了``Collator``的功能,用户可以通过设置``collate_fn="auto"``来 + 使用,并可以配套使用``set_pad``和``set_ignore``方法设置p``ad_val``和忽略某个field的pad操作。 """ - def __init__(self, dataset, batch_size: int = 1, + def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', @@ -60,18 +60,30 @@ class TorchDataLoader(DataLoader): :param dataset: 实现了__getitem__和__len__的数据容器 :param batch_size: 批次大小,当batch_sampler为None生效 :param shuffle: 是否打乱数据集 - :param sampler: sampler实例化对象 - :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 - :param num_workers: 进程的数量,当num_worker=0时不开启多进程 - :param collate_fn: [None, 'auto', callable] 对取得到的数据进行打包的callable函数 - :param pin_memory: - :param drop_last: 是否去掉最后一个不符合batch_size的数据 - :param timeout: - :param worker_init_fn: - :param multiprocessing_context: - :param generator: - :param prefetch_factor: - :param persistent_workers: + :param sampler: 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 + :param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, + 当其不为None时,bacth_size,sampler,shuffle均无效。 + :param num_workers: 开启子进程的数量,当num_worker=0时不开启多进程 + :param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. + + * ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; + 第二点注意的是此时``TorchDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。` + * ``callate_fn="auto"``时,``TorchDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, + 并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, + 可以调用``set_ignore``方法忽略某个field。 + * ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``TorchDataLoader``会调用传进来的callable函数对 + 数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 + + :param pin_memory: 如果其为True, 那么DataLoader会在返回数据张量之前将其copy到cuda的pin memory中。 + :param drop_last: 当``drop_last=True``时,``TorchDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据; + 若``drop_last=False``, 则什么也不做。 + :param timeout: 从子进程的输出队列获取数据的超时值 + :param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。 + :param multiprocessing_context: 多进程的上下文环境 + :param generator: 如果其不为None, 将会使用RandomSampler去生成随机的index并且多进程会每个子进程生成一个``base_seed`` + :param prefetch_factor: 每个worker提前装载的samples数量。``2``意味着在所有的进程中会有2*num_workers的数据被预取。默认值为2. + :param persistent_workers: 如果其为True, dataloader会在迭代完一次dataset后不会所有进程。默认为False + """ if isinstance(dataset, DataSet) and collate_fn is None: raise ValueError("When use FastNLP DataSet, collate_fn must be not None") @@ -111,9 +123,6 @@ class TorchDataLoader(DataLoader): self.cur_batch_indices = None def __iter__(self): - # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 - # if len(self._collate_fn.get_collators()) == 0: - # self._collate_fn.add_collator(self.collate_fn) self.collate_fn = indice_collate_wrapper(self.collate_fn) for indices, data in super().__iter__(): self.cur_batch_indices = indices @@ -132,12 +141,12 @@ class TorchDataLoader(DataLoader): field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 无意义。 :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 + :param backend: 可选['raw', 'numpy', 'torch', 'torch', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, + torch.Tensor, torch.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch 形式,输出将被直接作为结果输出。 - :return: 返回 Collator 自身 + :return: 返回 Collator """ collator = self._get_collator() if isinstance(collator, Collator): @@ -187,11 +196,10 @@ class TorchDataLoader(DataLoader): return self.cur_batch_indices - def prepare_torch_dataloader(ds_or_db, - batch_size: int = 16, + train_batch_size: int = 16, shuffle: bool = False, - sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', pin_memory: bool = False, drop_last: bool = False, @@ -201,32 +209,58 @@ def prepare_torch_dataloader(ds_or_db, non_train_batch_size: int = 16) \ -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: """ - 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 - - :param input_fields: - :param ds_or_db: dataset或者data_bundle - :param batch_size: 批次大小,当batch_sampler为None生效 + prepare_torch_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle, + Sequence[Dataset], Dict[name, Dataset]]``,具体如下: + + * 当ds_or_db为Dataset时,prepare_torch_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个 + torchDataLoader并返回。 + * 当ds_or_db为FastNLP的DataBundle时,prepare_torch_dataloader会遍历所有的dataset并根据其name实例化不同的torchDataLoader, + 当name中包含'train'字符串时,prepare_torch_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 + 的dataset均使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终根据name:torchDataLoader组成一个Dict[name, torchDataLoader] + 的数据返回。 + * 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_torch_dataloader会遍历所有的dataset并根据其name实例化不同的torchDataLoader, + 当name中包含'train'字符串时,prepare_torch_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 + 的dataset均使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终根据name:torchDataLoader组成一个Dict[name, torchDataLoader] + 的数据返回。 + * 当ds_or_db为Sequence[Dataset]数据类型时, prepare_torch_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为 + 其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终 + 将所有torchDataLoader组成Sequence[torchDataLoader]返回。 + + :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, + Sequence[Dataset], Dict[name, Dataset]]``. :param shuffle: 是否打乱数据集 - :param sampler: sampler实例化对象 - :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 - :param num_workers: 进程的数量,当num_worker=0时不开启多进程 - :param collate_fn: ['auto', None, callable]对取得到的数据进行打包的callable函数 - :param pin_memory: - :param drop_last: 是否去掉最后一个不符合batch_size的数据 - :param timeout: - :param worker_init_fn: - :param multiprocessing_context: - :param generator: - :param prefetch_factor: - :param persistent_workers: - :param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler - :param non_train_batch_size: + :param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 + :param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 + :param train_sampler: train'数据集使用的sampler, 现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 + :param non_train_sampler: 非'train'数据使用sampler, 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 + :param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, + 当其不为None时,bacth_size,sampler,shuffle均无效。 + :param num_workers: 开启子进程的数量,当num_worker=0时不开启多进程 + :param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. + + * ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; + 第二点注意的是此时``TorchDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。` + * ``callate_fn="auto"``时,``TorchDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, + 并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, + 可以调用``set_ignore``方法忽略某个field。 + * ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``TorchDataLoader``会调用传进来的callable函数对 + 数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 + + :param pin_memory: 如果其为True, 那么DataLoader会在返回数据张量之前将其copy到cuda的pin memory中。 + :param drop_last: 当``drop_last=True``时,``TorchDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据; + 若``drop_last=False``, 则什么也不做。 + :param timeout: 从子进程的输出队列获取数据的超时值 + :param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。 + :param multiprocessing_context: 多进程的上下文环境 + :param generator: 如果其不为None, 将会使用RandomSampler去生成随机的index并且多进程会每个子进程生成一个``base_seed`` + :param prefetch_factor: 每个worker提前装载的samples数量。``2``意味着在所有的进程中会有2*num_workers的数据被预取。默认值为2. + :param persistent_workers: 如果其为True, dataloader会在迭代完一次dataset后不会所有进程。默认为False """ from fastNLP.io import DataBundle if isinstance(ds_or_db, DataSet): - dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size, + shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -238,8 +272,8 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, + shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -247,8 +281,10 @@ def prepare_torch_dataloader(ds_or_db, persistent_workers=persistent_workers, ) else: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, + dl_bundle[name] = TorchDataLoader(dataset=ds, + batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + shuffle=shuffle, + sampler=non_train_sampler if non_train_sampler else train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -262,11 +298,11 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = [] for idx, ds in enumerate(ds_or_db): if idx > 0: - batch_size = non_train_batch_size if non_train_batch_size else batch_size - sampler = non_train_sampler if non_train_sampler else sampler + train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size + train_sampler = non_train_sampler if non_train_sampler else train_sampler dl_bundle.append( - TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + TorchDataLoader(dataset=ds, batch_size=train_batch_size, + shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -279,8 +315,8 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = {} for name, ds in ds_or_db.items(): if 'train' in name: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, + shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -288,8 +324,10 @@ def prepare_torch_dataloader(ds_or_db, persistent_workers=persistent_workers, ) else: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, - shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, + dl_bundle[name] = TorchDataLoader(dataset=ds, + batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + shuffle=shuffle, + sampler=non_train_sampler if non_train_sampler else train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, From 696f6b89f830baffefbeea1c71d1c354a057090e Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 18 May 2022 21:49:47 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=8F=82=E6=95=B0train?= =?UTF-8?q?=5Fbatch=5Fsize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/dataloaders/jittor_dataloader/fdl.py | 21 ++++++++++--------- .../core/dataloaders/paddle_dataloader/fdl.py | 21 ++++++++++--------- .../core/dataloaders/torch_dataloader/fdl.py | 21 ++++++++++--------- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 0705866d..cee3cf3d 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -185,7 +185,7 @@ class JittorDataLoader: return self.cur_batch_indices -def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: bool = True, +def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, collate_fn: Union[None, str, Callable] = "auto", @@ -211,8 +211,9 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``. - :param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 - :param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 + :param batch_size: batch 的大小。 + :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 + 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 :param shuffle: 是否打乱数据集 :param drop_last: 是否去掉最后一个不符合``batch_size``的数据 :param num_workers: 进程的数量,当``num_workers=0``时不开启多进程 @@ -234,7 +235,7 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo """ from fastNLP.io.data_bundle import DataBundle if isinstance(ds_or_db, Dataset): - dl = JittorDataLoader(ds_or_db, batch_size=train_batch_size, shuffle=shuffle, + dl = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, collate_fn=collate_fn) @@ -243,7 +244,7 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: - dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=train_batch_size, shuffle=shuffle, + dl_bundle[name] = JittorDataLoader(ds_or_db, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, @@ -251,7 +252,7 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo collate_fn=collate_fn) else: dl_bundle[name] = JittorDataLoader(ds_or_db, - batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, @@ -263,8 +264,8 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo ds_seq = [] for idx, ds in enumerate(ds_or_db): if idx > 0: - train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size - dl = JittorDataLoader(ds, batch_size=train_batch_size, shuffle=shuffle, + batch_size = non_train_batch_size if non_train_batch_size else batch_size + dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, collate_fn=collate_fn) @@ -275,13 +276,13 @@ def prepare_jittor_dataloader(ds_or_db, train_batch_size: int = 16, shuffle: boo ds_dict = {} for name, ds in ds_or_db.items(): if 'train' in name: - dl = JittorDataLoader(ds, batch_size=train_batch_size, shuffle=shuffle, + dl = JittorDataLoader(ds, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, keep_numpy_array=keep_numpy_array, endless=endless, collate_fn=collate_fn) else: dl = JittorDataLoader(ds_or_db, - batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers, buffer_size=buffer_size, diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index db7bc47e..d37f0ed7 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -253,7 +253,7 @@ class PaddleDataLoader(DataLoader): def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, return_list: bool = True, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, - train_batch_size: int = 16, shuffle: bool = False, + batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, @@ -280,8 +280,9 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``. - :param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 - :param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 + :param batch_size: batch 的大小。 + :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 + 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 :param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. The Tensors should be created by :code:`paddle.static.data()`. :attr:`feed_list` must be set if :attr:`return_list` is @@ -327,7 +328,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, from fastNLP.io.data_bundle import DataBundle if isinstance(ds_or_db, Dataset): dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) @@ -338,7 +339,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, if 'train' in name: dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=train_batch_size, + batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, @@ -349,7 +350,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, - batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, @@ -361,9 +362,9 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, ds_seq = [] for idx, ds in enumerate(ds_or_db): if idx > 0: - train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size + batch_size = non_train_batch_size if non_train_batch_size else batch_size dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) @@ -375,7 +376,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, for name, ds in ds_or_db.items(): if 'train' in name: dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, - batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, + batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, timeout=timeout, worker_init_fn=worker_init_fn, @@ -383,7 +384,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, else: dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, batch_sampler=batch_sampler, - batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 7c54bed7..99faec7e 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -197,7 +197,7 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db, - train_batch_size: int = 16, + batch_size: int = 16, shuffle: bool = False, train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, @@ -229,8 +229,9 @@ def prepare_torch_dataloader(ds_or_db, :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, Sequence[Dataset], Dict[name, Dataset]]``. :param shuffle: 是否打乱数据集 - :param train_batch_size: 'train'数据集使用的batch_size,跟non_train_batch_size是互斥的。 - :param non_train_batch_size: 非'train'数据使用batch_size,跟train_batch_size是互斥的。 + :param batch_size: batch 的大小。 + :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 + 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 :param train_sampler: train'数据集使用的sampler, 现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 :param non_train_sampler: 非'train'数据使用sampler, 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 :param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, @@ -259,7 +260,7 @@ def prepare_torch_dataloader(ds_or_db, from fastNLP.io import DataBundle if isinstance(ds_or_db, DataSet): - dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size, + dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -272,7 +273,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -282,7 +283,7 @@ def prepare_torch_dataloader(ds_or_db, ) else: dl_bundle[name] = TorchDataLoader(dataset=ds, - batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else train_sampler, batch_sampler=batch_sampler, @@ -298,10 +299,10 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = [] for idx, ds in enumerate(ds_or_db): if idx > 0: - train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size + batch_size = non_train_batch_size if non_train_batch_size else batch_size train_sampler = non_train_sampler if non_train_sampler else train_sampler dl_bundle.append( - TorchDataLoader(dataset=ds, batch_size=train_batch_size, + TorchDataLoader(dataset=ds, batch_size=batch_size, shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -315,7 +316,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle = {} for name, ds in ds_or_db.items(): if 'train' in name: - dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, + dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -325,7 +326,7 @@ def prepare_torch_dataloader(ds_or_db, ) else: dl_bundle[name] = TorchDataLoader(dataset=ds, - batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, + batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else train_sampler, batch_sampler=batch_sampler, From d2672c62b1134b60117b507e60bae21bf9088810 Mon Sep 17 00:00:00 2001 From: MorningForest <2297662686@qq.com> Date: Wed, 18 May 2022 21:54:12 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9fdl=20train=5Fsampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataloaders/torch_dataloader/fdl.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 99faec7e..9818ab39 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -199,7 +199,7 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False, - train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, + sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', pin_memory: bool = False, drop_last: bool = False, @@ -261,7 +261,7 @@ def prepare_torch_dataloader(ds_or_db, from fastNLP.io import DataBundle if isinstance(ds_or_db, DataSet): dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -274,7 +274,7 @@ def prepare_torch_dataloader(ds_or_db, for name, ds in ds_or_db.iter_datasets(): if 'train' in name: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -285,7 +285,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, - sampler=non_train_sampler if non_train_sampler else train_sampler, + sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, @@ -300,10 +300,10 @@ def prepare_torch_dataloader(ds_or_db, for idx, ds in enumerate(ds_or_db): if idx > 0: batch_size = non_train_batch_size if non_train_batch_size else batch_size - train_sampler = non_train_sampler if non_train_sampler else train_sampler + sampler = non_train_sampler if non_train_sampler else sampler dl_bundle.append( TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -317,7 +317,7 @@ def prepare_torch_dataloader(ds_or_db, for name, ds in ds_or_db.items(): if 'train' in name: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, @@ -328,7 +328,7 @@ def prepare_torch_dataloader(ds_or_db, dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, shuffle=shuffle, - sampler=non_train_sampler if non_train_sampler else train_sampler, + sampler=non_train_sampler if non_train_sampler else sampler, batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,