Browse Source

修改torch fdl

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
4be26c5620
2 changed files with 7 additions and 3 deletions
  1. +2
    -0
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  2. +5
    -3
      tests/core/dataloaders/torch_dataloader/test_fdl.py

+ 2
- 0
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -98,6 +98,7 @@ class TorchDataLoader(DataLoader):
 def __getattr__(self, item):
     """
     为FDataLoader提供dataset的方法和属性,实现该方法后,用户可以在FDataLoader实例化后使用apply等dataset的方法
+
     :param item:
     :return:
     """
@@ -119,6 +120,7 @@ class TorchDataLoader(DataLoader):
""" """
设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数 设置每个field_name的padding值,默认为0,只有当autocollate存在时该方法有效, 若没有则会添加auto_collator函数
当val=None时,意味着给定的field_names都不需要尝试padding 当val=None时,意味着给定的field_names都不需要尝试padding

:param field_names: :param field_names:
:param val: padding值,默认为0 :param val: padding值,默认为0
:return: :return:


+ 5
- 3
tests/core/dataloaders/torch_dataloader/test_fdl.py View File

@@ -21,11 +21,12 @@ class TestFdl:
ds.set_pad_val("x", val=-1) ds.set_pad_val("x", val=-1)
fdl = TorchDataLoader(ds, batch_size=3) fdl = TorchDataLoader(ds, batch_size=3)
fdl.set_input("x", "y") fdl.set_input("x", "y")
fdl.set_pad_val("x", val=None)
for batch in fdl: for batch in fdl:
print(batch) print(batch)
fdl.set_pad_val("x", val=-2)
for batch in fdl:
print(batch)
# fdl.set_pad_val("x", val=-2)
# for batch in fdl:
# print(batch)


     def test_add_collator(self):
         ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
@@ -38,6 +39,7 @@ class TestFdl:


         fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True)
         fdl.set_input("x", "y")
+        # fdl.set_pad_val("x", val=None)
         fdl.add_collator(collate_fn)
         for batch in fdl:
             print(batch)


Loading…
Cancel
Save