|
@@ -366,7 +366,7 @@ class Trainer(object): |
|
|
|
|
|
|
|
|
2. torch.device:将模型装载到torch.device上。 |
|
|
2. torch.device:将模型装载到torch.device上。 |
|
|
|
|
|
|
|
|
3. int: 将使用device_id为该值的gpu进行训练 |
|
|
|
|
|
|
|
|
3. int: 将使用该device的gpu进行训练 |
|
|
|
|
|
|
|
|
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 |
|
|
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 |
|
|
|
|
|
|
|
|