Browse Source

修改部分注释

tags/v0.4.10
yh 6 years ago
parent
commit
16388d5698
3 changed files with 7 additions and 5 deletions
  1. fastNLP/core/trainer.py (+2, -2)
  2. fastNLP/core/utils.py (+3, -3)
  3. test/core/test_utils.py (+2, -0)

+ 2
- 2
fastNLP/core/trainer.py View File

@@ -355,7 +355,7 @@ class Trainer(object):
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有 :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有
效。 效。
:param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模 :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模
- 型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
+ 型。保存的时候不仅保存了参数,还保存了模型结构。即便使用了nn.DataParallel,这里也只保存模型。
:param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。 :param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
@@ -366,7 +366,7 @@ class Trainer(object):


2. torch.device:将模型装载到torch.device上。 2. torch.device:将模型装载到torch.device上。


- 3. int: 将使用device_id为值的gpu进行训练
+ 3. int: 将使用该gpu进行训练


4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。




+ 3
- 3
fastNLP/core/utils.py View File

@@ -200,13 +200,13 @@ def _move_model_to_device(model, device):
else: else:
if not torch.cuda.is_available() and ( if not torch.cuda.is_available() and (
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
- raise ValueError("There is no usable gpu. set `device` as `cpu`.")
+ raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")


if isinstance(model, torch.nn.DataParallel): if isinstance(model, torch.nn.DataParallel):
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")


if isinstance(device, int): if isinstance(device, int):
- assert device>-1, "device can only be positive integer"
+ assert device>-1, "device can only be non-negative integer"
assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(), assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
device) device)
device = torch.device('cuda:{}'.format(device)) device = torch.device('cuda:{}'.format(device))
@@ -227,7 +227,7 @@ def _move_model_to_device(model, device):
assert list(types)[0] == int, "Only int supported for multiple devices." assert list(types)[0] == int, "Only int supported for multiple devices."
assert len(set(device))==len(device), "Duplicated device id found in device." assert len(set(device))==len(device), "Duplicated device id found in device."
for d in device: for d in device:
- assert d>-1, "Only positive device id allowed."
+ assert d>-1, "Only non-negative device id allowed."
if len(device)>1: if len(device)>1:
output_device = device[0] output_device = device[0]
model = nn.DataParallel(model, device_ids=device, output_device=output_device) model = nn.DataParallel(model, device_ids=device, output_device=output_device)


+ 2
- 0
test/core/test_utils.py View File

@@ -33,6 +33,8 @@ class TestMoveModelDeivce(unittest.TestCase):
assert model.param.device == torch.device('cuda:0') assert model.param.device == torch.device('cuda:0')
with self.assertRaises(Exception): with self.assertRaises(Exception):
_move_model_to_device(model, 'cuda:1000') _move_model_to_device(model, 'cuda:1000')
+ # 测试None
+ model = _move_model_to_device(model, None)


def test_case2(self): def test_case2(self):
# 测试使用int初始化 # 测试使用int初始化


Loading…
Cancel
Save