Browse Source

更新 'gpu_mnist_example/parallel_train.py'

liuzx
liuzx 3 months ago
parent
commit
3b5143d90c
1 changed files with 4 additions and 0 deletions
  1. +4
    -0
      gpu_mnist_example/parallel_train.py

+ 4
- 0
gpu_mnist_example/parallel_train.py View File

@@ -43,10 +43,14 @@ WORKERS = 0 # dataloder线程数
# 检查可用GPU数量
if torch.cuda.device_count() < 2:
raise RuntimeError("需要至少2块GPU,但当前只有 {} 块".format(torch.cuda.device_count()))
else:
print('当前有 {} 块GPU'.format(torch.cuda.device_count()))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. 初始化模型并并行化
model = Model().to(device)

print('开始进行并行训练!')
model = nn.DataParallel(model, device_ids=[0, 1]) # 使用GPU 0和1

optimizer = SGD(model.parameters(), lr=1e-1)


Loading…
Cancel
Save