|
|
@@ -11,7 +11,7 @@ class Test_WrapDataLoader: |
|
|
|
for sanity_batches in all_sanity_batches: |
|
|
|
data = NormalIterator(num_of_data=1000) |
|
|
|
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) |
|
|
|
dataloader = iter(wrapper(dataloader=data)) |
|
|
|
dataloader = iter(wrapper) |
|
|
|
mark = 0 |
|
|
|
while True: |
|
|
|
try: |
|
|
@@ -32,8 +32,7 @@ class Test_WrapDataLoader: |
|
|
|
dataset = TorchNormalDataset(num_of_data=1000) |
|
|
|
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) |
|
|
|
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) |
|
|
|
dataloader = wrapper(dataloader) |
|
|
|
dataloader = iter(dataloader) |
|
|
|
dataloader = iter(wrapper) |
|
|
|
all_supposed_running_data_num = 0 |
|
|
|
while True: |
|
|
|
try: |
|
|
@@ -55,6 +54,5 @@ class Test_WrapDataLoader: |
|
|
|
dataset = TorchNormalDataset(num_of_data=1000) |
|
|
|
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) |
|
|
|
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) |
|
|
|
dataloader = wrapper(dataloader) |
|
|
|
length.append(len(dataloader)) |
|
|
|
length.append(len(wrapper)) |
|
|
|
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) |