diff --git a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py index 17c151ea..bf9e3d9e 100644 --- a/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py +++ b/tests/core/dataloaders/torch_dataloader/test_mixdataloader.py @@ -17,7 +17,7 @@ d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100}) -def test_pad_val(tensor, val=0): +def _test_pad_val(tensor, val=0): if isinstance(tensor, torch.Tensor): tensor = tensor.tolist() for item in tensor: @@ -45,7 +45,7 @@ class TestMixDataLoader: if idx > 1: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) # collate_fn = Callable def collate_batch(batch): @@ -74,13 +74,13 @@ class TestMixDataLoader: dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True) for idx, batch in enumerate(dl2): if idx == 0: - assert test_pad_val(batch['x'], val=-1) + assert _test_pad_val(batch['x'], val=-1) assert batch['x'].shape == torch.Size([16, 4]) if idx == 1: - assert test_pad_val(batch['x'], val=-2) + assert _test_pad_val(batch['x'], val=-2) assert batch['x'].shape == torch.Size([16, 3]) if idx > 1: - assert test_pad_val(batch['x'], val=-3) + assert _test_pad_val(batch['x'], val=-3) assert batch['x'].shape == torch.Size([16, 4]) # sampler 为 str @@ -101,7 +101,7 @@ class TestMixDataLoader: if idx > 1: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) for idx, batch in enumerate(dl4): if idx == 0: @@ -118,7 +118,7 @@ class TestMixDataLoader: if idx > 1: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) # sampler 为 Dict samplers = {'d1': SequentialSampler(d1), @@ -137,7 +137,7 @@ class TestMixDataLoader: if idx > 1: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) # ds_ratio 为 'truncate_to_least' dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True) @@ -154,7 +154,7 @@ class TestMixDataLoader: # d3 assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]] assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx > 2: raise ValueError(f"ds_ratio: 'truncate_to_least' error") @@ -170,7 +170,7 @@ class TestMixDataLoader: if 36 <= idx < 54: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 54: raise ValueError(f"ds_ratio: 'pad_to_most' error") @@ -187,7 +187,7 @@ class TestMixDataLoader: if 4 <= idx < 41: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 41: raise ValueError(f"ds_ratio: 'pad_to_most' error") @@ -201,7 +201,7 @@ class TestMixDataLoader: # d3 assert batch['x'].shape == torch.Size([16, 4]) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 19: raise ValueError(f"ds_ratio: 'pad_to_most' error") @@ -209,7 +209,7 @@ class TestMixDataLoader: datasets = {'d1': d1, 'd2': d2, 'd3': d3} dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True) for idx, batch in enumerate(dl): - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 22: raise ValueError(f"out of range") @@ -224,7 +224,7 @@ class TestMixDataLoader: dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True) for idx, batch in enumerate(dl1): assert isinstance(batch['x'], list) - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 22: raise ValueError(f"out of range") @@ -237,12 +237,12 @@ class TestMixDataLoader: # sampler 为 str dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True) for idx, batch in enumerate(dl3): - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 22: raise ValueError(f"out of range") dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True) for idx, batch in enumerate(dl4): - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 22: raise ValueError(f"out of range") # sampler 为 Dict @@ -251,7 +251,7 @@ class TestMixDataLoader: 'd3': RandomSampler(d3)} dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True) for idx, batch in enumerate(dl5): - assert test_pad_val(batch['x'], val=0) + assert _test_pad_val(batch['x'], val=0) if idx >= 22: raise ValueError(f"out of range") # ds_ratio 为 'truncate_to_least' @@ -333,7 +333,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 20: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # collate_fn = Callable def collate_batch(batch): @@ -361,16 +361,16 @@ class TestMixDataLoader: dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18) for idx, batch in enumerate(dl1): if idx == 0 or idx == 3: - assert test_pad_val(batch['x'], val=-1) + assert _test_pad_val(batch['x'], val=-1) assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]] assert batch['x'].shape[1] == 4 elif idx == 1 or idx == 4: # d2 - assert test_pad_val(batch['x'], val=-2) + assert _test_pad_val(batch['x'], val=-2) assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]] assert batch['x'].shape[1] == 3 elif idx == 2 or 4 < idx <= 20: - assert test_pad_val(batch['x'], val=-3) + assert _test_pad_val(batch['x'], val=-3) assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]] assert batch['x'].shape[1] == 4 if idx > 20: @@ -392,7 +392,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 20: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) for idx, batch in enumerate(dl3): if idx == 0 or idx == 3: assert batch['x'].shape[1] == 4 @@ -403,7 +403,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 20: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # sampler 为 Dict samplers = {'d1': SequentialSampler(d1), 'd2': SequentialSampler(d2), @@ -421,7 +421,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 20: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # ds_ratio 为 'truncate_to_least' dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18) @@ -438,7 +438,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 5: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # ds_ratio 为 'pad_to_most' dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18) @@ -457,7 +457,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx >= 51: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) # ds_ratio 为 Dict[str, float] ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0} @@ -475,7 +475,7 @@ class TestMixDataLoader: assert batch['x'].shape[1] == 4 if idx > 39: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) + _test_pad_val(batch['x'], val=0) ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0} dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18) @@ -493,4 +493,4 @@ class TestMixDataLoader: if idx > 18: raise ValueError(f"out of range") - test_pad_val(batch['x'], val=0) \ No newline at end of file + _test_pad_val(batch['x'], val=0) \ No newline at end of file