diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index d56dbac9..7f0d1bb6 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -98,6 +98,7 @@ class TorchDataLoader(DataLoader):
     def __getattr__(self, item):
         """
         Provide the dataset's methods and attributes through FDataLoader; once this method is implemented, users can call dataset methods such as apply on an instantiated FDataLoader.
+        :param item:
         :return:
         """
@@ -119,6 +120,7 @@ class TorchDataLoader(DataLoader):
         """
         Set the padding value for each field_name, 0 by default. This method only takes effect when autocollate exists; if it does not, an auto_collator function will be added.
         When val=None, none of the given field_names will be padded.
+        :param field_names:
         :param val: padding value, defaults to 0
         :return:
         """
diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py
index baa3781a..7c1352aa 100644
--- a/tests/core/dataloaders/torch_dataloader/test_fdl.py
+++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py
@@ -21,11 +21,12 @@ class TestFdl:
         ds.set_pad_val("x", val=-1)
         fdl = TorchDataLoader(ds, batch_size=3)
         fdl.set_input("x", "y")
+        fdl.set_pad_val("x", val=None)
         for batch in fdl:
             print(batch)
-        fdl.set_pad_val("x", val=-2)
-        for batch in fdl:
-            print(batch)
+        # fdl.set_pad_val("x", val=-2)
+        # for batch in fdl:
+        #     print(batch)
 
     def test_add_collator(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
@@ -38,6 +39,7 @@ class TestFdl:
 
         fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
         fdl.set_input("x", "y")
+        # fdl.set_pad_val("x", val=None)
         fdl.add_collator(collate_fn)
         for batch in fdl:
             print(batch)
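
For reference, a minimal sketch of how the documented val=None behaviour of TorchDataLoader.set_pad_val is exercised, mirroring the updated test above. The import paths are assumptions inferred from the file paths in this diff rather than confirmed public API; adjust them if the package exposes these classes elsewhere.

    # Sketch only: import paths are assumed from the file layout shown in this diff.
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders.torch_dataloader.fdl import TorchDataLoader

    ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
    fdl = TorchDataLoader(ds, batch_size=3)
    fdl.set_input("x", "y")
    # Per the new docstring line, val=None means the "x" field should not be padded at all.
    fdl.set_pad_val("x", val=None)
    for batch in fdl:
        print(batch)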