Browse Source

修改test_mixdataloader中的工具函数名,防止被pytest执行

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
da0b747b30
1 changed files with 28 additions and 28 deletions
  1. +28
    -28
      tests/core/dataloaders/torch_dataloader/test_mixdataloader.py

+ 28
- 28
tests/core/dataloaders/torch_dataloader/test_mixdataloader.py View File

@@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10]
d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})


def test_pad_val(tensor, val=0):
def _test_pad_val(tensor, val=0):
if isinstance(tensor, torch.Tensor):
tensor = tensor.tolist()
for item in tensor:
@@ -45,7 +45,7 @@ class TestMixDataLoader:
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)

# collate_fn = Callable
def collate_batch(batch):
@@ -74,13 +74,13 @@ class TestMixDataLoader:
dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
for idx, batch in enumerate(dl2):
if idx == 0:
assert test_pad_val(batch['x'], val=-1)
assert _test_pad_val(batch['x'], val=-1)
assert batch['x'].shape == torch.Size([16, 4])
if idx == 1:
assert test_pad_val(batch['x'], val=-2)
assert _test_pad_val(batch['x'], val=-2)
assert batch['x'].shape == torch.Size([16, 3])
if idx > 1:
assert test_pad_val(batch['x'], val=-3)
assert _test_pad_val(batch['x'], val=-3)
assert batch['x'].shape == torch.Size([16, 4])

# sampler 为 str
@@ -101,7 +101,7 @@ class TestMixDataLoader:
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)

for idx, batch in enumerate(dl4):
if idx == 0:
@@ -118,7 +118,7 @@ class TestMixDataLoader:
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)

# sampler 为 Dict
samplers = {'d1': SequentialSampler(d1),
@@ -137,7 +137,7 @@ class TestMixDataLoader:
if idx > 1:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)

# ds_ratio 为 'truncate_to_least'
dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
@@ -154,7 +154,7 @@ class TestMixDataLoader:
# d3
assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx > 2:
raise ValueError(f"ds_ratio: 'truncate_to_least' error")

@@ -170,7 +170,7 @@ class TestMixDataLoader:
if 36 <= idx < 54:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 54:
raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -187,7 +187,7 @@ class TestMixDataLoader:
if 4 <= idx < 41:
# d3
assert batch['x'].shape == torch.Size([16, 4])
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 41:
raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -201,7 +201,7 @@ class TestMixDataLoader:
# d3
assert batch['x'].shape == torch.Size([16, 4])

assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 19:
raise ValueError(f"ds_ratio: 'pad_to_most' error")

@@ -209,7 +209,7 @@ class TestMixDataLoader:
datasets = {'d1': d1, 'd2': d2, 'd3': d3}
dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
for idx, batch in enumerate(dl):
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")

@@ -224,7 +224,7 @@ class TestMixDataLoader:
dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
for idx, batch in enumerate(dl1):
assert isinstance(batch['x'], list)
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")

@@ -237,12 +237,12 @@ class TestMixDataLoader:
# sampler 为 str
dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
for idx, batch in enumerate(dl3):
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
for idx, batch in enumerate(dl4):
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
# sampler 为 Dict
@@ -251,7 +251,7 @@ class TestMixDataLoader:
'd3': RandomSampler(d3)}
dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
for idx, batch in enumerate(dl5):
assert test_pad_val(batch['x'], val=0)
assert _test_pad_val(batch['x'], val=0)
if idx >= 22:
raise ValueError(f"out of range")
# ds_ratio 为 'truncate_to_least'
@@ -333,7 +333,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)

# collate_fn = Callable
def collate_batch(batch):
@@ -361,16 +361,16 @@ class TestMixDataLoader:
dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
for idx, batch in enumerate(dl1):
if idx == 0 or idx == 3:
assert test_pad_val(batch['x'], val=-1)
assert _test_pad_val(batch['x'], val=-1)
assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
assert batch['x'].shape[1] == 4
elif idx == 1 or idx == 4:
# d2
assert test_pad_val(batch['x'], val=-2)
assert _test_pad_val(batch['x'], val=-2)
assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
assert batch['x'].shape[1] == 3
elif idx == 2 or 4 < idx <= 20:
assert test_pad_val(batch['x'], val=-3)
assert _test_pad_val(batch['x'], val=-3)
assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
assert batch['x'].shape[1] == 4
if idx > 20:
@@ -392,7 +392,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)
for idx, batch in enumerate(dl3):
if idx == 0 or idx == 3:
assert batch['x'].shape[1] == 4
@@ -403,7 +403,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)
# sampler 为 Dict
samplers = {'d1': SequentialSampler(d1),
'd2': SequentialSampler(d2),
@@ -421,7 +421,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx > 20:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)

# ds_ratio 为 'truncate_to_least'
dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
@@ -438,7 +438,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx > 5:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)

# ds_ratio 为 'pad_to_most'
dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
@@ -457,7 +457,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx >= 51:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)

# ds_ratio 为 Dict[str, float]
ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
@@ -475,7 +475,7 @@ class TestMixDataLoader:
assert batch['x'].shape[1] == 4
if idx > 39:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)

ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
@@ -493,4 +493,4 @@ class TestMixDataLoader:

if idx > 18:
raise ValueError(f"out of range")
test_pad_val(batch['x'], val=0)
_test_pad_val(batch['x'], val=0)

Loading…
Cancel
Save