Browse Source

修改部分注释

tags/v0.4.10
yh 5 years ago
parent
commit
16388d5698
3 changed files with 7 additions and 5 deletions
  1. +2
    -2
      fastNLP/core/trainer.py
  2. +3
    -3
      fastNLP/core/utils.py
  3. +2
    -0
      test/core/test_utils.py

+ 2
- 2
fastNLP/core/trainer.py View File

@@ -355,7 +355,7 @@ class Trainer(object):
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
:param str,None save_path: 模型的保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模
型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
型。保存的时候不仅保存了参数,还保存了模型结构。即便使用了nn.DataParallel,这里也只保存模型。
:param bool prefetch: 是否使用额外的进程来产生batch数据。理论上会使得Batch迭代更快。
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
@@ -366,7 +366,7 @@ class Trainer(object):

2. torch.device:将模型装载到torch.device上。

3. int: 将使用device_id为值的gpu进行训练
3. int: 将使用device_id为该值的gpu进行训练

4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。



+ 3
- 3
fastNLP/core/utils.py View File

@@ -200,13 +200,13 @@ def _move_model_to_device(model, device):
else:
if not torch.cuda.is_available() and (
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
raise ValueError("There is no usable gpu. set `device` as `cpu`.")
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")

if isinstance(model, torch.nn.DataParallel):
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")

if isinstance(device, int):
assert device>-1, "device can only be positive integer"
assert device>-1, "device can only be non-negative integer"
assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
device)
device = torch.device('cuda:{}'.format(device))
@@ -227,7 +227,7 @@ def _move_model_to_device(model, device):
assert list(types)[0] == int, "Only int supported for multiple devices."
assert len(set(device))==len(device), "Duplicated device id found in device."
for d in device:
assert d>-1, "Only positive device id allowed."
assert d>-1, "Only non-negative device id allowed."
if len(device)>1:
output_device = device[0]
model = nn.DataParallel(model, device_ids=device, output_device=output_device)


+ 2
- 0
test/core/test_utils.py View File

@@ -33,6 +33,8 @@ class TestMoveModelDeivce(unittest.TestCase):
assert model.param.device == torch.device('cuda:0')
with self.assertRaises(Exception):
_move_model_to_device(model, 'cuda:1000')
# 测试None
model = _move_model_to_device(model, None)

def test_case2(self):
# 测试使用int初始化


Loading…
Cancel
Save