@@ -47,6 +47,7 @@ from fastNLP.core.collators.collator import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.core.dataset import DataSet as FDataSet
 from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler
+from ..utils import _match_param

 class _PaddleDataset(Dataset):
@@ -154,14 +155,17 @@ class PaddleDataLoader(DataLoader):
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")

-        super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
-                                               return_list=return_list, batch_sampler=batch_sampler,
-                                               batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
-                                               collate_fn=collate_fn, num_workers=num_workers,
-                                               use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
-                                               timeout=timeout, worker_init_fn=worker_init_fn,
-                                               persistent_workers=persistent_workers)
+        dl_kwargs = _match_param(PaddleDataLoader.__init__, DataLoader.__init__, DataLoader.__name__)
+        if dl_kwargs is None:
+            super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places,
+                                                   return_list=return_list, batch_sampler=batch_sampler,
+                                                   batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
+                                                   collate_fn=collate_fn, num_workers=num_workers,
+                                                   use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory,
+                                                   timeout=timeout, worker_init_fn=worker_init_fn,
+                                                   persistent_workers=persistent_workers)
+        else:
+            super().__init__(**dl_kwargs)

         # _collate_fn = _MultiCollator(AutoCollator(as_numpy=True))
         # if collate_fn is not None:
         #     _collate_fn.add_collator(collate_fn)
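Both loader hunks follow the same idiom: ask _match_param for a keyword set that the installed backend DataLoader.__init__ actually accepts, and fall back to the original explicit call whenever matching returns None, so the pre-change behaviour is preserved on failure. Below is a minimal, self-contained sketch of that idiom; Base and Wrapper are illustrative stand-ins, not part of this diff.

from fastNLP.core.dataloaders.utils import _match_param

class Base:
    def __init__(self, a, b=0):
        self.total = a + b

class Wrapper(Base):
    # `extra` has no counterpart in Base.__init__
    def __init__(self, a, b=0, extra=None):
        kwargs = _match_param(Wrapper.__init__, Base.__init__, Base.__name__)
        if kwargs is None:
            super().__init__(a=a, b=b)  # matching failed: explicit call, old behaviour
        else:
            super().__init__(**kwargs)  # forward only the parameters Base supports

assert Wrapper(1, 2).total == 3  # `extra` left at its default: no warning
Wrapper(1, 2, extra='x')         # warns: Parameter:extra is not supported for Base.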
@@ -11,6 +11,7 @@ from fastNLP.core.collators import Collator
 from fastNLP.core.dataloaders.utils import indice_collate_wrapper
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
+from ..utils import _match_param

 if _NEED_IMPORT_TORCH:
     from torch.utils.data import DataLoader, Sampler
@@ -96,12 +97,16 @@ class TorchDataLoader(DataLoader):
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")

-        super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
-                         batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
-                         pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
-                         multiprocessing_context=multiprocessing_context, generator=generator,
-                         prefetch_factor=prefetch_factor,
-                         persistent_workers=persistent_workers)
+        dl_kwargs = _match_param(TorchDataLoader.__init__, DataLoader.__init__, fn_name=DataLoader.__name__)
+        if dl_kwargs is None:
+            super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
+                             batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
+                             pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+                             multiprocessing_context=multiprocessing_context, generator=generator,
+                             prefetch_factor=prefetch_factor,
+                             persistent_workers=persistent_workers)
+        else:
+            super().__init__(**dl_kwargs)

         self.cur_batch_indices = None
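What makes this matching possible is plain runtime introspection: the constructor signature of whichever DataLoader is installed tells _match_param which keywords are safe to forward. A hedged illustration (assumes torch is importable; to my recollection both parameters below first appeared in torch 1.7):

import inspect
from torch.utils.data import DataLoader

# names the installed DataLoader.__init__ accepts, minus the bound instance
supported = set(inspect.signature(DataLoader.__init__).parameters) - {'self'}
print('prefetch_factor' in supported)     # False on torch 1.6, True on >= 1.7
print('persistent_workers' in supported)  # likewise absent before torch 1.7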
@@ -1,4 +1,9 @@
 from typing import Callable
+import inspect
+import ast
+
+from ..log import logger
+from ..utils.cache_results import get_func_calls, truncate_start_blanks

 __all__ = [
     "indice_collate_wrapper"
 ]
@@ -25,6 +30,72 @@ def indice_collate_wrapper(func:Callable):
     return _indice_collate_wrapper

+def _match_param(fun, call_fn:Callable, fn_name:str=None):
+    """
+    The function that calls _match_param (that is, ``fun``) will in turn call ``call_fn``. ``fun`` may accept
+    more parameters than ``call_fn`` does; for example, some parameters of :class:`~.fastNLP.TorchDataLoader`
+    are not supported by the DataLoader of torch 1.6 but are supported in newer torch versions. This function
+    therefore works out which parameters are suitable to pass to the current version's DataLoader for
+    initialization, and emits a warning for every parameter that was explicitly set but is not supported.
+
+    :param fun: the calling function itself
+    :param call_fn: the function that is going to be called
+    :param fn_name: function name used to make error messages more readable
+    :return: the keyword arguments accepted by ``call_fn``, or ``None`` if matching failed
+    """
+    try:
+        if fn_name is None:
+            try:
+                fn_name = call_fn.__name__
+            except AttributeError:
+                fn_name = str(call_fn)
+
+        last_frame = inspect.currentframe().f_back
+
+        # collect the default values of fun's parameters (fun is the caller of _match_param)
+        fun_default_params = {}
+        fun_parameters = inspect.signature(fun)
+        for name, fun_param in fun_parameters.parameters.items():
+            if fun_param.default is not fun_param.empty:
+                fun_default_params[name] = fun_param.default
+
+        # collect the argument values that were actually passed to fun
+        param_names, args_name, kwargs_name, values = inspect.getargvalues(last_frame)
+        if args_name is not None:
+            raise RuntimeError("Function does not support var-positional arguments, such as: fun(*args).")
+        kwargs = values.get(kwargs_name, {})
+        for param in param_names:
+            if param not in values:
+                value = fun_default_params.get(param)
+            else:
+                value = values[param]
+            kwargs[param] = value
+
+        # match against the parameters that call_fn actually accepts
+        call_fn_parameters = inspect.signature(call_fn)
+        call_fn_kwargs = {}
+        has_kwargs = False
+        for name, param in call_fn_parameters.parameters.items():
+            if name == 'self':
+                continue
+            if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY):  # named parameters
+                call_fn_kwargs[name] = param.default
+            if param.kind == param.VAR_KEYWORD:
+                has_kwargs = True
+
+        # assemble the final keyword arguments
+        call_kwargs = {}
+        for name, value in kwargs.items():
+            if name in call_fn_kwargs or has_kwargs:  # keep it if call_fn declares it or accepts **kwargs
+                call_kwargs[name] = value
+            # not accepted by call_fn, yet explicitly set to a non-default value: warn
+            elif name not in call_fn_kwargs and name in fun_default_params and fun_default_params[name] != value:
+                logger.rank_zero_warning(f"Parameter:{name} is not supported for {fn_name}.")
+
+        return call_kwargs
+    except RuntimeError:
+        raise  # misuse of _match_param itself (e.g. *args in fun) must surface to the caller
+    except BaseException as e:
+        logger.debug(f"Exception happens when match parameters for {fn_name}: {e}")
+        return None
+
 if __name__ == '__main__':
     def demo(*args, **kwargs):
         pass
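To make the warning semantics concrete, here is a small usage sketch with two hypothetical functions, old_init and new_init, named only for this example:

from fastNLP.core.dataloaders.utils import _match_param

def old_init(a, b):        # stands in for an older backend __init__
    return a + b

def new_init(a, b, c=10):  # the wrapper accepts one extra parameter
    kwargs = _match_param(new_init, old_init)
    if kwargs is not None:
        return old_init(**kwargs)  # `c` is filtered out before the call
    return old_init(a, b)          # matching failed: explicit fallback

new_init(1, 2)       # -> 3, no warning: c was left at its default
new_init(1, 2, c=5)  # -> 3, warns "Parameter:c is not supported for old_init."

The warning fires only for parameters explicitly set away from their defaults; unset extras are dropped silently, which is what lets the fastNLP loaders keep one signature across backend versions.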
@@ -0,0 +1,39 @@
+import pytest
+
+from fastNLP.core.dataloaders.utils import _match_param
+from fastNLP import logger
+
+from tests.helpers.utils import recover_logger, Capturing
+
+
+def demo():
+    pass
+
+
+def test_no_args():
+    def f(*args, a, b, **kwarg):
+        c = 100
+        call_kwargs = _match_param(f, demo)
+
+    with pytest.raises(RuntimeError):
+        f(a=1, b=2)
+
+    def f(a, *args, b, **kwarg):
+        c = 100
+        call_kwargs = _match_param(f, demo)
+
+    with pytest.raises(RuntimeError):
+        f(a=1, b=2)
+
+
+@recover_logger
+def test_warning():
+    logger.set_stdout('raw')
+
+    def f1(a, b):
+        return 1
+
+    def f2(a, b, c=2):
+        kwargs = _match_param(f2, f1)
+        return f1(**kwargs)
+
+    with Capturing() as out:
+        f2(a=1, b=2, c=3)
+    assert 'Parameter:c' in out[0]  # c was passed explicitly, so a warning is expected
+
+    assert f2(1, 2) == 1
@@ -5,6 +5,9 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.io.data_bundle import DataBundle
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 from fastNLP.core import Trainer
+from pkg_resources import parse_version
+from tests.helpers.utils import Capturing, recover_logger
+from fastNLP import logger

 if _NEED_IMPORT_TORCH:
     import torch
@@ -128,3 +131,33 @@ class TestFdl:
         dl = DataLoader(MyDatset(), collate_fn=collate_batch)
         for batch in dl:
             print(batch)
+
+    @recover_logger
+    def test_version_16(self):
+        if parse_version(torch.__version__) >= parse_version('1.7'):
+            pytest.skip("Torch version is 1.7 or newer")
+        logger.set_stdout()
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with Capturing() as out:
+            dl = TorchDataLoader(ds, prefetch_factor=3, shuffle=False)
+            for idx, batch in enumerate(dl):
+                assert len(batch['x']) == 1
+                assert batch['x'][0].tolist() == ds[idx]['x']
+        assert 'Parameter:prefetch_factor' in out[0]
+
+    @recover_logger
+    def test_version_111(self):
+        if parse_version(torch.__version__) <= parse_version('1.7'):
+            pytest.skip("Torch version is 1.7 or older")
+        logger.set_stdout()
+        ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+        with Capturing() as out:
+            dl = TorchDataLoader(ds, num_workers=2, prefetch_factor=3, shuffle=False)
+            for idx, batch in enumerate(dl):
+                assert len(batch['x']) == 1
+                assert batch['x'][0].tolist() == ds[idx]['x']
+        assert 'Parameter:prefetch_factor' not in out[0]
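One note on the version gate in these tests: parse_version compares release numbers numerically, where naive string comparison mis-orders multi-digit components, e.g.:

from pkg_resources import parse_version

assert parse_version('1.10.0') > parse_version('1.7')  # numeric ordering, as intended
assert '1.10.0' < '1.7'                                # string ordering gets it wrong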