diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py index ab81cefc..def18a15 100644 --- a/tests/core/metrics/test_accuracy_torch.py +++ b/tests/core/metrics/test_accuracy_torch.py @@ -96,7 +96,7 @@ class TestAccuracy: metric_kwargs=metric_kwargs, sklearn_metric=sklearn_accuracy, ), - [(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] + [(rank, processes, torch.device(f'cuda:{rank+4}')) for rank in range(processes)] ) else: device = torch.device(