diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 02b56cd7..8800be8e 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -48,6 +48,7 @@ __all__ = [ 'prepare_jittor_dataloader', 'prepare_paddle_dataloader', 'prepare_torch_dataloader', + "prepare_dataloader", # dataset 'DataSet', diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 10229f66..a51406af 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -32,7 +32,7 @@ class CheckpointCallback(Callback): model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。 - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py index 1e867866..ad1c95cd 100644 --- a/fastNLP/core/callbacks/early_stop_callback.py +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -12,7 +12,7 @@ class EarlyStopCallback(HasMonitorCallback): def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): """ - :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index 2d1affd2..d934e24a 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -34,7 +34,7 @@ class ResultsMonitor: """ 可用于监控某个数值,并通过 is_better_results() 等接口实现检测结果是否变得更好了。 - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param larger_better: monitor 是否时越大越好 @@ -171,7 +171,7 @@ class HasMonitorCallback(ResultsMonitor, Callback): 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param larger_better: monitor 是否时越大越好 @@ -209,7 +209,7 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): """ 当监控的 monitor 结果更好的时候,调用 execute_fn 函数。 - :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param larger_better: monitor 是否时越大越好 diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 227c80c4..362716ef 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -21,7 +21,7 @@ class LoadBestModelCallback(HasMonitorCallback): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 :param larger_better: 该 metric 值是否是越大越好。 diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index 33415b7a..896f8865 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -37,7 +37,7 @@ class MoreEvaluateCallback(HasMonitorCallback): 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 - 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 + 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的monitor 值请返回 None 。 :param watch_monitor_larger_better: watch_monitor 是否越大越好。 diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 335345e0..2ce177e2 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -46,7 +46,7 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 - 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor + 完全一致的名称,将使用 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 相关的 monitor 值请返回 None 。 :param larger_better: 是否是 monitor 的结果越大越好。 @@ -141,7 +141,7 @@ class RawTextCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 - 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor + 完全一致的名称,将使用 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 相关的 monitor 值请返回 None 。 :param larger_better: 是否是monitor的结果越大越好。 diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 09843511..aba2ff63 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -183,7 +183,7 @@ class TopkSaver(ResultsMonitor, Saver): :param topk: 保存 topk 多少的模型,-1 为保存所有模型;0 为都不保存;大于 0 的数为保存 topk 个。 :param monitor: 监控哪个指标判断是否是 topk 的。监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 - 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数, + 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数, 接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请 返回 None 。 :param larger_better: 该 monitor 是否越大越好。 diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 2e4f23b8..aef9de4c 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -6,19 +6,20 @@ from typing import List, Union, Dict, Callable, Sequence, Mapping import os import sys import inspect +import re from fastNLP.core.log import logger from .padders.get_padder import get_padder +from ...envs import SUPPORT_BACKENDS -import re from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \ NestedMappingPackerUnpacker sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] -CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend - +# 由于 jittor DataLoader 存在自动的 to_jittor 的转换,所以只需要 collate 成为 numpy 就行 +AUTO_BACKEND_MAPPING = {'jittor': 'numpy'} def _get_backend() -> str: """ @@ -40,7 +41,7 @@ def _get_backend() -> str: catch_backend = [] try: file = module.__file__ - for backend in CHECK_BACKEND: + for backend in SUPPORT_BACKENDS: if f'{os.sep}site-packages{os.sep}{backend}' in file: catch_backend = [backend, file] except: @@ -62,10 +63,10 @@ def _get_backend() -> str: break if len(catch_backend): logger.debug(f"Find a file named:{catch_backend[1]} from stack contains backend:{catch_backend[0]}.") - return catch_backend[0] + return AUTO_BACKEND_MAPPING.get(catch_backend[0], catch_backend[0]) # 方式 (2) - for backend in CHECK_BACKEND: + for backend in SUPPORT_BACKENDS: if backend in sys.modules: logger.debug(f"sys.modules contains backend:{backend}.") return backend diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 7e91ec42..c4dbdadc 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -30,7 +30,8 @@ if _NEED_IMPORT_PADDLE: } from .padder import Padder -from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class +from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, is_numpy_generic_class, \ + get_padded_numpy_array from .exceptions import * @@ -54,7 +55,6 @@ def is_paddle_dtype_str(dtype): return False - def _get_dtype(ele_dtype, dtype, class_name): if not (ele_dtype is None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " @@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder): def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], paddle.Tensor): - batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] + batch_field = [np.array(field.tolist()) for field in batch_field] else: if dtype is None: dtype = batch_field[0].dtype @@ -141,46 +141,14 @@ class PaddleTensorPadder(Padder): shapes = [field.shape for field in batch_field] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] - tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) + array = np.full(max_shape, fill_value=pad_val) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) - tensor[slices] = field + array[slices] = field + tensor = paddle.to_tensor(array, dtype=dtype) return tensor -def fill_tensor(batch_field, padded_batch, dtype): - """ - 将 batch_field 中的值填入到 tensor 中。 - - :param batch_field: 需要填充进入 array 中的内容 - :param padded_batch: 待填充的 tensor - :param dtype: 数据的类别 - - :return: - """ - if padded_batch.ndim == 2: - for i, content_i in enumerate(batch_field): - padded_batch[i, :len(content_i)] = paddle.to_tensor(content_i, dtype=dtype) - elif padded_batch.ndim == 3: - for i, content_i in enumerate(batch_field): - for j, content_ii in enumerate(content_i): - padded_batch[i, j, :len(content_ii)] = paddle.to_tensor(content_ii, dtype=dtype) - elif padded_batch.ndim == 4: - try: # 应该是图像,所以直接应该就 ok 了。 - padded_batch = np.array(batch_field) - except: - for i, content_i in enumerate(batch_field): - for j, content_ii in enumerate(content_i): - for k, content_iii in enumerate(content_ii): - padded_batch[i, j, k, :len(content_iii)] = paddle.to_tensor(content_iii, dtype=dtype) - elif padded_batch.ndim == 1: - padded_batch[:] = paddle.to_tensor(batch_field, dtype=dtype) - else: - raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " - "report.") - return padded_batch - - def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): """ 例如: @@ -192,7 +160,6 @@ def get_padded_paddle_tensor(batch_field, dtype=None, pad_val=0): :param pad_val: pad 的 value :return: """ - shapes = get_shape(batch_field) - tensor = paddle.to_tensor(np.full(shape=shapes, fill_value=pad_val), dtype=dtype) - tensor = fill_tensor(batch_field, tensor, dtype=dtype) + array = get_padded_numpy_array(batch_field=batch_field, dtype=None, pad_val=pad_val) + tensor = paddle.to_tensor(array, dtype=dtype) return tensor diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index d82577f8..abd70644 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -161,6 +161,7 @@ class Evaluator: self.reset() self.driver.barrier() except BaseException as e: + self.driver.on_exception() raise e finally: self.finally_progress_bar() diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index f9ab144c..59a4501b 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -125,7 +125,7 @@ class Trainer(TrainerEventTrigger): :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 - 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 :param larger_better: monitor 的值是否是越大越好。 @@ -372,6 +372,14 @@ class Trainer(TrainerEventTrigger): self.on_exception(e) if not catch_KeyboardInterrupt: raise e + except RuntimeError as e: + if 'torch' in self.driver_name.lower(): # 如果是 torch ,需要检测一下 find_unused_parameters + if 'find_unused_parameters' in e.args[0]: + logger.error("You may need to pass `torch_ddp_kwargs={'find_unused_parameters': True}` in the " + "Trainer initialization to avoid this error.") + self.driver.on_exception() + self.on_exception(e) + raise e except BaseException as e: self.driver.on_exception() self.on_exception(e) diff --git a/fastNLP/core/dataloaders/__init__.py b/fastNLP/core/dataloaders/__init__.py index 40dd7b1c..976788d9 100644 --- a/fastNLP/core/dataloaders/__init__.py +++ b/fastNLP/core/dataloaders/__init__.py @@ -5,10 +5,13 @@ __all__ = [ 'JittorDataLoader', 'prepare_jittor_dataloader', 'prepare_paddle_dataloader', - 'prepare_torch_dataloader' + 'prepare_torch_dataloader', + + "prepare_dataloader" ] from .mix_dataloader import MixDataLoader from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader +from .prepare_dataloader import prepare_dataloader \ No newline at end of file diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index 434fe7f9..8ecd2d87 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -4,6 +4,7 @@ __all__ = [ ] from typing import Callable, Optional, List, Union +from copy import deepcopy from fastNLP.envs.imports import _NEED_IMPORT_JITTOR @@ -75,10 +76,12 @@ class JittorDataLoader: if isinstance(collate_fn, str): if collate_fn == "auto": if isinstance(self.dataset.dataset, FDataSet): - self.collate_fn = self.dataset.dataset.collator - self.collate_fn.set_backend(backend="jittor") + self.collate_fn = deepcopy(self.dataset.dataset.collator) + # jittor 比较特殊,只需要保证返回 numpy.array, 其Dataloader会转为jt.var + self.collate_fn.set_backend(backend="numpy") else: - self.collate_fn = Collator(backend="jittor") + # jittor 比较特殊,只需要保证返回 numpy.array, 其Dataloader会转为jt.var + self.collate_fn = Collator(backend="numpy") else: raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") elif isinstance(collate_fn, Callable): diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 977197f6..5c5e3bef 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -4,6 +4,7 @@ __all__ = [ ] from typing import Callable, List, Optional, Union, Dict, Sequence +from copy import deepcopy from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -68,7 +69,7 @@ class PaddleDataLoader(DataLoader): if isinstance(collate_fn, str): if collate_fn == 'auto': if isinstance(dataset.dataset, FDataSet): - collate_fn = dataset.dataset.collator + collate_fn = deepcopy(dataset.dataset.collator) collate_fn.set_backend(backend="paddle") else: collate_fn = Collator(backend="paddle") diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py new file mode 100644 index 00000000..193ec384 --- /dev/null +++ b/fastNLP/core/dataloaders/prepare_dataloader.py @@ -0,0 +1,114 @@ +__all__ = [ + 'prepare_dataloader' +] + +from typing import Union, Callable +import os +import sys + +from ..samplers import RandomBatchSampler, RandomSampler +from .torch_dataloader import prepare_torch_dataloader +from .paddle_dataloader import prepare_paddle_dataloader +from .jittor_dataloader import prepare_jittor_dataloader +from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS, _module_available +from ..log import logger + + +def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, + collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, + seed: int = 0, backend: str = 'auto'): + """ + 自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 + 返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 + :func:`~fastNLP.prepare_torch_dataloader` 或 :func:`~fastNLP.prepare_paddle_dataloader` 等。 + + :param dataset: 实现 __getitem__() 和 __len__() 的对象;或这种对象的序列;或字典。 + + * 为单个数据集对象时 + 返回一个 DataLoader 。 + * 为数据集对象序列时 + 返回一个序列的 DataLoader 。 + * 为字典型 或 :class:`~fastNLP.io.DataBundle` 数据时,返回 `Dict` 类型的数据。 + 返回一个字典 。 + + :param batch_size: 批次大小。 + :param shuffle: 是否打乱数据集。 + :param drop_last: 当最后一个 batch 不足 batch_size 数量的是否,是否丢弃。 + :param collate_fn: 用于处理一个 batch 的函数,一般包括 padding 和转为 tensor。有以下三种取值: + + * 为 ``auto`` 时 + 使用 :class:`~fastNLP.Collator` 进行 padding 和 转tensor 。 + * 为 ``Callable`` 时 + 应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 + * 为 ``None`` 时 + 使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 + :param num_workers: 使用多少进程进行数据的 fetch 。 + :param seed: 使用的随机数种子。 + :param backend: 当前支持 ``["auto", "torch", "paddle", "jittor"]`` 四种类型。 + + * 为 ``auto`` 时 + 首先(1) 根据环境变量 "FASTNLP_BACKEND" 进行判断;如果没有设置则,(2)通过当前 ``sys.modules`` 中已经 import 的 + ``backend`` 进行判定。如果以上均无法判定,则报错。如果找到了 ``backend`` ,则按照下述的方式处理。 + * 为 ``torch`` 时 + 使用 :func:`~fastNLP.prepare_torch_dataloader` 。 + * 为 ``paddle`` 时 + 使用 :func:`~fastNLP.prepare_paddle_dataloader` 。 + * 为 ``jittor`` 时 + 使用 :func:`~fastNLP.prepare_jittor_dataloader` 。 + + :return + """ + if backend == 'auto': + backend = _get_backend() + if backend == 'torch': + batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, + drop_last=drop_last, seed=seed) + return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, + num_workers=num_workers, shuffle=False, sampler=None) + elif backend == 'paddle': + batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, + drop_last=drop_last, seed=seed) + return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, + num_workers=num_workers) + elif backend == 'jittor': + sampler = RandomSampler(dataset=dataset, shuffle=shuffle, seed=seed) + prepare_jittor_dataloader(ds_or_db=dataset, sampler=sampler, collate_fn=collate_fn, + num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, + drop_last=drop_last) + else: + raise ValueError(f"Currently we do not support backend:{backend}.") + + +def _check_module(module): + """ + 检查该 module 是否含有 某个 backend 的特征 + + :param module: module 对象 + :return: + """ + try: + file = module.__file__ + for backend in SUPPORT_BACKENDS: + if f'{os.sep}site-packages{os.sep}{backend}' in file: + return backend + except: + pass + return None + + +def _get_backend(): + if os.environ.get(FASTNLP_BACKEND, None) != None: + backend = os.environ.get(FASTNLP_BACKEND) + logger.debug(f"Get Dataloader backend:{backend} from os.environ") + else: + available_backends = set() + for module in sys.modules.values(): + _backend = _check_module(module) + if _backend: + available_backends.add(_backend) + if len(available_backends) == 1: + backend = available_backends.pop() + logger.debug(f"Get Dataloader backend:{backend} from sys.modules.") + else: + raise RuntimeError("Fail to detect dataloader backend automatically, please set it manually.") + return backend \ No newline at end of file diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 48fee045..6a9e4af9 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -4,7 +4,7 @@ __all__ = [ ] from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List -import inspect +from copy import deepcopy from fastNLP.core.dataset import DataSet from fastNLP.core.collators import Collator @@ -84,7 +84,7 @@ class TorchDataLoader(DataLoader): if isinstance(collate_fn, str): if collate_fn == 'auto': if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset - collate_fn = dataset.dataset.collator + collate_fn = deepcopy(dataset.dataset.collator) collate_fn.set_backend(backend="torch") else: collate_fn = Collator(backend="torch") @@ -178,8 +178,8 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], - batch_size: int = 16, - shuffle: bool = True, + batch_size: int = 1, + shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', @@ -250,26 +250,15 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping elif isinstance(ds_or_db, Sequence): dl_bundle = [] for idx, ds in enumerate(ds_or_db): - if idx == 0: - dl_bundle.append( - TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, - num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, - drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, - multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - ) - ) - else: - dl_bundle.append( - TorchDataLoader(dataset=ds, batch_size=batch_size, - shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, - num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, - drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, - multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, - ) - ) + dl_bundle.append( + TorchDataLoader(dataset=ds, batch_size=batch_size, + shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, + num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, generator=generator, + prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + ) + ) return dl_bundle elif isinstance(ds_or_db, Mapping): diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 6cc49278..b94e7bde 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -285,7 +285,7 @@ class PaddleFleetDriver(PaddleDriver): self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM")) self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID")) reset_seed() - logger.info(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n") + logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") if not parallel_helper._is_parallel_ctx_initialized(): fleet.init(self.role_maker, self.is_collective, self.strategy) diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 2f976f18..85491b2e 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -251,7 +251,7 @@ class TorchDDPDriver(TorchDriver): self.world_size = int(os.environ.get("WORLD_SIZE")) self.global_rank = int(os.environ.get("RANK")) reset_seed() - logger.info(f"World size:{self.world_size}, Global rank:{self.global_rank}") + logger.info(f"World size: {self.world_size}, Global rank: {self.global_rank}") if not dist.is_initialized(): dist.init_process_group( diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index df8b48f7..1ca83c09 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -61,7 +61,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi elif device is not None and not isinstance(device, torch.device): raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") - if driver == "torch": + if driver == "torch": # single, ddp, 直接启动。 if not isinstance(device, List): return TorchSingleDriver(model, device, **kwargs) else: diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index edb41032..b07d8b82 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -22,6 +22,8 @@ import numpy as np from pathlib import Path from fastNLP.core.log import logger +from ...envs import SUPPORT_BACKENDS + __all__ = [ 'get_fn_arg_names', diff --git a/tests/core/dataloaders/test_prepare_dataloader.py b/tests/core/dataloaders/test_prepare_dataloader.py new file mode 100644 index 00000000..223b7880 --- /dev/null +++ b/tests/core/dataloaders/test_prepare_dataloader.py @@ -0,0 +1,13 @@ +import pytest + +from fastNLP import prepare_dataloader +from fastNLP import DataSet + + +@pytest.mark.torch +def test_torch(): + import torch + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + dl = prepare_dataloader(ds, batch_size=2, shuffle=True) + for batch in dl: + assert isinstance(batch['x'], torch.Tensor) \ No newline at end of file