|
|
@@ -96,7 +96,7 @@ class TestAccuracy: |
|
|
|
metric_kwargs=metric_kwargs, |
|
|
|
sklearn_metric=sklearn_accuracy, |
|
|
|
), |
|
|
|
[(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] |
|
|
|
[(rank, processes, torch.device(f'cuda:{rank+4}')) for rank in range(processes)] |
|
|
|
) |
|
|
|
else: |
|
|
|
device = torch.device( |
|
|
|