|
|
@@ -28,12 +28,12 @@ class TestUnrepeatedSampler: |
|
|
|
@pytest.mark.parametrize('num_replicas', [2, 3]) |
|
|
|
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) |
|
|
|
@pytest.mark.parametrize('shuffle', [False, True]) |
|
|
|
def test_multi(self, num_replica, num_of_data, shuffle): |
|
|
|
def test_multi(self, num_replicas, num_of_data, shuffle): |
|
|
|
data = DatasetWithVaryLength(num_of_data=num_of_data) |
|
|
|
samplers = [] |
|
|
|
for i in range(num_replica): |
|
|
|
for i in range(num_replicas): |
|
|
|
sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle) |
|
|
|
sampler.set_distributed(num_replica, rank=i) |
|
|
|
sampler.set_distributed(num_replicas, rank=i) |
|
|
|
samplers.append(sampler) |
|
|
|
|
|
|
|
indexes = list(chain(*samplers)) |
|
|
@@ -52,12 +52,12 @@ class TestUnrepeatedSortedSampler: |
|
|
|
|
|
|
|
@pytest.mark.parametrize('num_replicas', [2, 3]) |
|
|
|
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) |
|
|
|
def test_multi(self, num_replica, num_of_data): |
|
|
|
def test_multi(self, num_replicas, num_of_data): |
|
|
|
data = DatasetWithVaryLength(num_of_data=num_of_data) |
|
|
|
samplers = [] |
|
|
|
for i in range(num_replica): |
|
|
|
for i in range(num_replicas): |
|
|
|
sampler = UnrepeatedSortedSampler(dataset=data, length=data.data) |
|
|
|
sampler.set_distributed(num_replica, rank=i) |
|
|
|
sampler.set_distributed(num_replicas, rank=i) |
|
|
|
samplers.append(sampler) |
|
|
|
|
|
|
|
# 保证顺序是没乱的 |
|
|
@@ -83,12 +83,12 @@ class TestUnrepeatedSequentialSampler: |
|
|
|
|
|
|
|
@pytest.mark.parametrize('num_replicas', [2, 3]) |
|
|
|
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) |
|
|
|
def test_multi(self, num_replica, num_of_data): |
|
|
|
def test_multi(self, num_replicas, num_of_data): |
|
|
|
data = DatasetWithVaryLength(num_of_data=num_of_data) |
|
|
|
samplers = [] |
|
|
|
for i in range(num_replica): |
|
|
|
for i in range(num_replicas): |
|
|
|
sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data) |
|
|
|
sampler.set_distributed(num_replica, rank=i) |
|
|
|
sampler.set_distributed(num_replicas, rank=i) |
|
|
|
samplers.append(sampler) |
|
|
|
|
|
|
|
# 保证顺序是没乱的 |
|
|
|