From 3b5143d90cfe5e710563242bb3260e6fa09d00b1 Mon Sep 17 00:00:00 2001 From: liuzx Date: Wed, 30 Jul 2025 11:00:33 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'gpu=5Fmnist=5Fexample/par?= =?UTF-8?q?allel=5Ftrain.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gpu_mnist_example/parallel_train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/gpu_mnist_example/parallel_train.py b/gpu_mnist_example/parallel_train.py index 4d2ac02..1f95d6f 100644 --- a/gpu_mnist_example/parallel_train.py +++ b/gpu_mnist_example/parallel_train.py @@ -43,10 +43,14 @@ WORKERS = 0 # dataloder线程数 # 检查可用GPU数量 if torch.cuda.device_count() < 2: raise RuntimeError("需要至少2块GPU,但当前只有 {} 块".format(torch.cuda.device_count())) +else: + print('当前有 {} 块GPU'.format(torch.cuda.device_count())) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 2. 初始化模型并并行化 model = Model().to(device) + +print('开始进行并行训练!') model = nn.DataParallel(model, device_ids=[0, 1]) # 使用GPU 0和1 optimizer = SGD(model.parameters(), lr=1e-1)