|
@@ -16,7 +16,9 @@ from collections import Counter, namedtuple
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
+from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.parallel_apply import parallel_apply

 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
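These three imports pull in the low-level primitives behind `nn.DataParallel`. A minimal sketch of the scatter/gather half, assuming at least two visible CUDA devices (the tensor shape is made up for illustration):

```python
import torch
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather

if torch.cuda.device_count() >= 2:
    x = torch.randn(8, 4)
    # Chunk the (positional, keyword) arguments along dim 0, one shard per device.
    inputs, kwargs = scatter_kwargs((x,), {}, [0, 1], dim=0)
    # inputs[0][0] now lives on cuda:0 and inputs[1][0] on cuda:1, 4 rows each.
    shards = [inp[0] for inp in inputs]
    # gather concatenates the shards back together on the target device.
    merged = gather(shards, 0)
    assert merged.shape == (8, 4) and merged.device == torch.device('cuda:0')
```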
|
@@ -277,6 +279,25 @@ def _move_model_to_device(model, device):
     model = model.to(device)
     return model


+def _data_parallel_wrapper(func, device_ids, output_device):
+    """
+    Wrap a callable so that it executes data-parallel across multiple GPUs.
+    Modeled on the forward function of nn.DataParallel.
+
+    :param func: the callable to run in parallel
+    :param device_ids: device ids, as in nn.DataParallel
+    :param output_device: device onto which the outputs are gathered
+    :return: the wrapped callable
+    """
+    def wrapper(*inputs, **kwargs):
+        # Split positional and keyword arguments into per-device shards along dim 0.
+        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
+        if len(device_ids) == 1:
+            return func(*inputs[0], **kwargs[0])
+        # One replica of the module per input shard.
+        replicas = replicate(func, device_ids[:len(inputs)])
+        # Run each replica on its shard in parallel, then collect the results.
+        outputs = parallel_apply(replicas, inputs, kwargs)
+        return gather(outputs, output_device)
+    return wrapper
+
+
 def _get_model_device(model):
     """
|