| @@ -21,8 +21,13 @@ else: | |||||
| class _FDataSet: | class _FDataSet: | ||||
| """ | """ | ||||
| 对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset | |||||
| 中调用dataset的方法 | |||||
| 提供给 ``TorchDataLoader`` 使用的 warp 类,其功能是对 dataset 进行封装,wrap 修改 dataset 的 __getitem__ 函数,增加返回 | |||||
| 数据的下标 idx 。 | |||||
| ..note:: | |||||
| 需要注意的是传入 ``__init__`` 的 dataset 需要实现 __getattribute__ 方法才能在 _FDataset 实例化对象中调用 dataset 的方法 | |||||
| """ | """ | ||||
| def __init__(self, dataset) -> None: | def __init__(self, dataset) -> None: | ||||
| @@ -43,8 +48,17 @@ class _FDataSet: | |||||
| class TorchDataLoader(DataLoader): | class TorchDataLoader(DataLoader): | ||||
| """ | """ | ||||
| 提供给``torch``框架使用的``DataLoader``函数,``TorchDataLoader``提供了``Collator``的功能,用户可以通过设置``collate_fn="auto"``来 | |||||
| 使用,并可以配套使用``set_pad``和``set_ignore``方法设置p``ad_val``和忽略某个field的pad操作。 | |||||
| 提供给 ``torch`` 框架使用的 ``DataLoader`` 函数,``TorchDataLoader`` 提供了 ``Collator`` 来自动检测dataset的每个field是否可pad, | |||||
| 若是可 pad 的 field 则自动 pad 到相同长度,否则只会将相同 field 的数据收集组成一个 batch 返回。 | |||||
| 具体详见 :class:`~fastNLP.core.collators.Collator`;用户通过 callte_fn 来控制是否使用该功能, collate_fn 只能为 ``['auto', None, Callable]``三种取值。 | |||||
| * callate_fn 为 ``'auto'`` 时,``TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的取值。 | |||||
| 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
| * callate_fn 为 ``None`` 时, ``TorchDataLoadr`` 默认使用 torch DataLoader 自带的 collate_fn | |||||
| * collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
| dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
| """ | """ | ||||
| def __init__(self, dataset, batch_size: int = 16, | def __init__(self, dataset, batch_size: int = 16, | ||||
| @@ -57,32 +71,34 @@ class TorchDataLoader(DataLoader): | |||||
| persistent_workers: bool = False, **kwargs) -> None: | persistent_workers: bool = False, **kwargs) -> None: | ||||
| """ | """ | ||||
| :param dataset: 实现了__getitem__和__len__的数据容器 | |||||
| :param batch_size: 批次大小,当batch_sampler为None生效 | |||||
| :param shuffle: 是否打乱数据集 | |||||
| :param sampler: 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | |||||
| :param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, | |||||
| 当其不为None时,bacth_size,sampler,shuffle均无效。 | |||||
| :param num_workers: 开启子进程的数量,当num_worker=0时不开启多进程 | |||||
| :param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. | |||||
| * ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; | |||||
| 第二点注意的是此时``TorchDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。` | |||||
| * ``callate_fn="auto"``时,``TorchDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, | |||||
| 并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, | |||||
| 可以调用``set_ignore``方法忽略某个field。 | |||||
| * ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``TorchDataLoader``会调用传进来的callable函数对 | |||||
| 数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 | |||||
| :param pin_memory: 如果其为True, 那么DataLoader会在返回数据张量之前将其copy到cuda的pin memory中。 | |||||
| :param drop_last: 当``drop_last=True``时,``TorchDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据; | |||||
| 若``drop_last=False``, 则什么也不做。 | |||||
| :param timeout: 从子进程的输出队列获取数据的超时值 | |||||
| :param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。 | |||||
| :param dataset: 实现了 __getitem__() 和 __len__() 的对象。 | |||||
| :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
| :param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
| :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
| 默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
| :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
| dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||||
| :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
| 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
| :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
| * callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
| ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
| :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
| * callate_fn 为 ``'auto'`` 时,`TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
| 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
| * `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
| dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
| :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||||
| :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
| 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
| :param timeout: 子进程的输出队列获取数据的超时值 | |||||
| :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
| :param multiprocessing_context: 多进程的上下文环境 | :param multiprocessing_context: 多进程的上下文环境 | ||||
| :param generator: 如果其不为None, 将会使用RandomSampler去生成随机的index并且多进程会每个子进程生成一个``base_seed`` | |||||
| :param prefetch_factor: 每个worker提前装载的samples数量。``2``意味着在所有的进程中会有2*num_workers的数据被预取。默认值为2. | |||||
| :param persistent_workers: 如果其为True, dataloader会在迭代完一次dataset后不会所有进程。默认为False | |||||
| :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed`` | |||||
| :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
| :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
| """ | """ | ||||
| if isinstance(dataset, DataSet) and collate_fn is None: | if isinstance(dataset, DataSet) and collate_fn is None: | ||||
| @@ -209,53 +225,62 @@ def prepare_torch_dataloader(ds_or_db, | |||||
| non_train_batch_size: int = 16) \ | non_train_batch_size: int = 16) \ | ||||
| -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | ||||
| """ | """ | ||||
| prepare_torch_dataloader的功能是将多个dataset同时转为dataloader返回。ds_or_db的类型只能为``[Dataset, DataBundle, | |||||
| Sequence[Dataset], Dict[name, Dataset]]``,具体如下: | |||||
| * 当ds_or_db为Dataset时,prepare_torch_dataloader会将所有的参数除了non_train_batch_size以外来帮你实例化一个 | |||||
| torchDataLoader并返回。 | |||||
| * 当ds_or_db为FastNLP的DataBundle时,prepare_torch_dataloader会遍历所有的dataset并根据其name实例化不同的torchDataLoader, | |||||
| 当name中包含'train'字符串时,prepare_torch_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 | |||||
| 的dataset均使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终根据name:torchDataLoader组成一个Dict[name, torchDataLoader] | |||||
| 的数据返回。 | |||||
| * 当ds_or_db为Dict[name, Dataset]数据类型时,prepare_torch_dataloader会遍历所有的dataset并根据其name实例化不同的torchDataLoader, | |||||
| 当name中包含'train'字符串时,prepare_torch_dataloader默认其为train数据,并将train_batch_size传为其中,其他不包含'train'字符串 | |||||
| 的dataset均使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终根据name:torchDataLoader组成一个Dict[name, torchDataLoader] | |||||
| 的数据返回。 | |||||
| * 当ds_or_db为Sequence[Dataset]数据类型时, prepare_torch_dataloader会将Sequence[0]作为默认的train数据集对待,并使用train_batch_size作为 | |||||
| 其batch_size使用;而Sequence[1:]均视为非train数据集对待,使用non_train_batch_size作为batch_size来实例化torchDataLoader。最终 | |||||
| 将所有torchDataLoader组成Sequence[torchDataLoader]返回。 | |||||
| :param ds_or_db: 传进来的dataset集合或字典或为dataset或DataBundle。其取值只能为``[Dataset, DataBundle, | |||||
| Sequence[Dataset], Dict[name, Dataset]]``. | |||||
| :param shuffle: 是否打乱数据集 | |||||
| :param batch_size: batch 的大小。 | |||||
| :param non_train_batch_size: 如果传入的 ``ds_or_db`` 为 ``Dict`` 或 :class:`~fastNLP.io.DataBundle` 对象,可以通过改参数 | |||||
| 设置名称不为 `train` 的其他 ``dataset`` 的 ``batch_size``。 | |||||
| :param train_sampler: train'数据集使用的sampler, 现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | |||||
| :param non_train_sampler: 非'train'数据使用sampler, 实现了__len__和__iter__方法的实例化对象,其功能是每次返回dataset的一个index,当其不为None时,shuffle参数无效 | |||||
| :param batch_sampler: 实现了__len__和__iter__方法的实例化对象,,其能迭代返回一个list的index数据, index不超过dataset的大小, | |||||
| 当其不为None时,bacth_size,sampler,shuffle均无效。 | |||||
| :param num_workers: 开启子进程的数量,当num_worker=0时不开启多进程 | |||||
| :param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. | |||||
| * ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; | |||||
| 第二点注意的是此时``TorchDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。` | |||||
| * ``callate_fn="auto"``时,``TorchDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, | |||||
| 并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, | |||||
| 可以调用``set_ignore``方法忽略某个field。 | |||||
| * ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``TorchDataLoader``会调用传进来的callable函数对 | |||||
| 数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 | |||||
| :param pin_memory: 如果其为True, 那么DataLoader会在返回数据张量之前将其copy到cuda的pin memory中。 | |||||
| :param drop_last: 当``drop_last=True``时,``TorchDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据; | |||||
| 若``drop_last=False``, 则什么也不做。 | |||||
| :param timeout: 从子进程的输出队列获取数据的超时值 | |||||
| :param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。 | |||||
| :param multiprocessing_context: 多进程的上下文环境 | |||||
| :param generator: 如果其不为None, 将会使用RandomSampler去生成随机的index并且多进程会每个子进程生成一个``base_seed`` | |||||
| :param prefetch_factor: 每个worker提前装载的samples数量。``2``意味着在所有的进程中会有2*num_workers的数据被预取。默认值为2. | |||||
| :param persistent_workers: 如果其为True, dataloader会在迭代完一次dataset后不会所有进程。默认为False | |||||
| ``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class: `~fastNLP.core.dataloaders.TorchDataLoader`。 | |||||
| 根据 ds_or_db 的类型 ``[DataSet, DataBundle,Sequence[Dataset], Dict[name, Dataset]]`` 不同而有不同返回结果, 具体如下: | |||||
| * 当 ds_or_db 为 ``DataSet``时,``prepare_torch_dataloader`` 会将使用的除了 non_train_batch_size 和 non_train_sampler 以外的参数来 | |||||
| 帮你实例化一个 ``TorchDataLoader`` 对象并返回该对象。 详见:class: `~fastNLP.core.dataloaders.TorchDataLoader`。 | |||||
| * 当 ds_or_db 为 :class:`~fastNLP.io.DataBundle` 时,``prepare_torch_dataloader`` 会遍历 ``DataBundle`` 的数据集的 key-value | |||||
| 来创建不同的 ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集, | |||||
| 会将 batch_size 和 sampler 作为参数,其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
| 最终根据 ``key: TorchDataLoader`` 组成 ``Dict[key, TorchDataLoader]`` 的字典返回。 | |||||
| * 当 ds_or_db 为 ``Dict[str, DataSet]`` 字典类型时, ``prepare_torch_dataloader`` 会遍历 该 dict 的的 key-value 来创建不同的 | |||||
| ``TorchDataLoader`` 对象;当 key 中包含'train'字符串时,``prepare_torch_dataloader`` 默认该 value 为 train 数据集,会将 batch_size 和 sampler 作为参数, | |||||
| 其他 key 不包含 'train' 字符串的数据集则使用 non_train_size 和 non_train_sampler 作为参数。最终根据 ``key: TorchDataLoader`` 组成 | |||||
| ``Dict[key, TorchDataLoader]`` 的字典返回。 | |||||
| * 当 ds_or_db 为 ``Sequence[Dataset]`` 数据类型时, prepare_torch_dataloader 会将 Sequence[0] 的数据集默认为 train 数据集对待, | |||||
| 会将 batch_size 和 sampler 作为参数, 而 Sequence[1:] 数据集均视为非 train 数据集对待,使用 non_train_size 和 non_train_sampler 作为参数。 | |||||
| 最终将所有实例化好的 ``TorchDataLoader`` 组成 ``Sequence[TorchDataLoader]`` 返回。 | |||||
| :param ds_or_db: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。其取值只能为 ``[DataSet, DataBundle, | |||||
| Sequence[DataSet], Dict[str, DataSet]]``. | |||||
| * ds_or_db 为 :class: `~fastNLP.core.dataset.DataSet`,返回值为:class: `~fastNLP.core.dataloaders.TorchDataLoader` | |||||
| * ds_or_db 为 :class: `~fastNLP.io.DataBundle`, 返回值为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
| * ds_or_db 为 ``Sequence[DataSet]`` 序列, 返回值为 ``Sequence[TorchDataLoader]`` 的序列 | |||||
| * ds_or_db 为 ``Dict[str, DataSet]`` 字典, 返回值也为 ``Dict[str, TorchDataLoader]`` 的字典 | |||||
| :param batch_size: 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
| :param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 | |||||
| :param shuffle: 是否打乱数据集, 默认为 ``False``。 | |||||
| :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
| 默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
| :param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , | |||||
| 默认为None, 当其不为 None 时, shuffle 参数无效。 | |||||
| :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 | |||||
| dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 | |||||
| :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 | |||||
| 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 | |||||
| :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. | |||||
| * callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, | |||||
| ``TorchDataLoader`` 调用默认的 torch 框架的 ``DataLoader`` 自带的 `default_collate_fn` 作为 callate_fn 的默认值, 其无法处理 | |||||
| :class:`~fastNLP.core.dataset.DataSet` 的dataset对象。 | |||||
| * callate_fn 为 ``'auto'`` 时,`TorchDataLoader`` 使用 :class:`~fastNLP.core.collators.Collator` 作为 collate_fn 的默认值。 | |||||
| 此时可以配套使用 ``TorchDataLoader`` 的 ``set_pad`` 和 ``set_ignore`` 方法来设置 pad_val 或 忽略某个 field 的检测。 | |||||
| * `collate_fn 为 ``Callable`` 时, 该 Callable 函数应当接受一个 batch 参数作为输入, batch 是一个 List 对象且 List 中的每一条数据都是 | |||||
| dataset 的一条数据;该 Callable 函数还应当返回一个对象。 | |||||
| :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 | |||||
| :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; | |||||
| 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 | |||||
| :param timeout: 子进程的输出队列获取数据的超时值 | |||||
| :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 | |||||
| :param multiprocessing_context: 多进程的上下文环境 | |||||
| :param generator: 如果其不为 ``None``, 将会使用 RandomSampler 去生成随机的 index 且会为每个子进程生成一个``base_seed`` | |||||
| :param prefetch_factor: 每个 worker 提前装载的 samples 数量。``2``意味着在所有的进程中会有 2*num_workers 的数据被预取。默认值为 ``2`` . | |||||
| :param persistent_workers: 如果其为 ``True``, ``TorchDataLoader`` 在迭代完一次 dataset 后不会关闭所有进程。默认为 ``False`` | |||||
| """ | """ | ||||
| from fastNLP.io import DataBundle | from fastNLP.io import DataBundle | ||||