@@ -0,0 +1,88 @@
import threading

import torch
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate

def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
    r"""Applies, in parallel, the method named :attr:`func_name` of each module
    in :attr:`modules` on the arguments contained in :attr:`inputs` (positional)
    and :attr:`kwargs_tup` (keyword), one call per device in :attr:`devices`.

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object passed as the only
    argument to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = getattr(module, func_name)(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                # store the exception so the caller can re-raise it in the main thread
                results[i] = e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


def _data_parallel_wrapper(func_name, device_ids, output_device):
    """
    Wrapper for methods that need to run on multiple GPUs, modeled on the
    forward pass of nn.DataParallel.

    :param str func_name: name of the network method to run in parallel
    :param device_ids: same as device_ids in nn.DataParallel
    :param output_device: same as output_device in nn.DataParallel
    :return: a wrapper taking (network, *inputs, **kwargs) that scatters the
        arguments, replicates the network, and gathers the results
    """
    def wrapper(network, *inputs, **kwargs):
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
            return getattr(network, func_name)(*inputs[0], **kwargs[0])
        replicas = replicate(network, device_ids[:len(inputs)])
        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
        return gather(outputs, output_device)
    return wrapper
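
For context, here is a minimal usage sketch of the new helper. It is not part of the patch: it assumes at least two visible GPUs, that the module above is importable as fastNLP.core._parallel_utils, and a toy model exposing a `predict` method; all names in the snippet are illustrative.

# Usage sketch (illustrative, not part of the patch).
from functools import partial

import torch
import torch.nn as nn

# assumed import path for the new module added in this patch
from fastNLP.core._parallel_utils import _data_parallel_wrapper


class ToyModel(nn.Module):
    """Toy model exposing the `predict` method the wrapper will call."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 2)

    def predict(self, x):
        return self.linear(x)


if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    device_ids = [0, 1]
    model = ToyModel().to('cuda:0')
    # Bind the model as `network` with partial(), mirroring what Tester does below.
    predict_fn = partial(_data_parallel_wrapper('predict', device_ids, output_device=0),
                         network=model)
    # The wrapper must be called with keyword arguments: `network` is bound by
    # keyword, so a positional argument would collide with it.
    x = torch.randn(16, 8, device='cuda:0')
    out = predict_fn(x=x)  # x is scattered across both GPUs, outputs are gathered on cuda:0
    print(out.shape)       # torch.Size([16, 2])
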
@@ -32,8 +32,6 @@ Before validation starts, Tester calls model.eval() to signal that it has entered evaluation
""" | """ | ||||
import warnings | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
@@ -48,7 +46,8 @@ from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device
-from .utils import _data_parallel_wrapper
+from ._parallel_utils import _data_parallel_wrapper
+from functools import partial

__all__ = [
    "Tester"
@@ -111,8 +110,10 @@ class Tester(object):
                (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
                 callable(self._model.module.predict)):
            if isinstance(self._model, nn.DataParallel):
-                self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids,
-                                                                    self._model.output_device)
+                self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
+                                                                             self._model.device_ids,
+                                                                             self._model.output_device),
+                                                     network=self._model.module)
                self._predict_func = self._model.module.predict
            else:
                self._predict_func = self._model.predict
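
Passing the method name 'predict' rather than the bound method matters because `replicate` can only copy an `nn.Module`; each replica then looks up its own copy of the method via `getattr`. Binding the unwrapped module with `partial` keeps the call site unchanged. Roughly, and with an illustrative batch dict, the bound wrapper expands to the sketch below.

# Hedged sketch of what self._predict_func_wrapper(**batch_x) amounts to at call time.
wrapper = _data_parallel_wrapper('predict', self._model.device_ids, self._model.output_device)
pred_dict = wrapper(self._model.module, **batch_x)  # batch kwargs are scattered, outputs gathered
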
@@ -16,9 +16,6 @@ from collections import Counter, namedtuple
import numpy as np
import torch
import torch.nn as nn
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
-from torch.nn.parallel.replicate import replicate
-from torch.nn.parallel.parallel_apply import parallel_apply

_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                     'varargs'])
@@ -279,24 +276,6 @@ def _move_model_to_device(model, device):
    model = model.to(device)
    return model

-def _data_parallel_wrapper(func, device_ids, output_device):
-    """
-    Wrapper for functions that need to be executed on multiple GPUs, modeled on
-    the forward pass of nn.DataParallel.
-
-    :param func: callable
-    :param device_ids: same as device_ids in nn.DataParallel
-    :param inputs:
-    :param kwargs:
-    :return:
-    """
-    def wrapper(*inputs, **kwargs):
-        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
-        if len(device_ids) == 1:
-            return func(*inputs[0], **kwargs[0])
-        replicas = replicate(func, device_ids[:len(inputs)])
-        outputs = parallel_apply(replicas, inputs, kwargs)
-        return gather(outputs, output_device)
-    return wrapper

def _get_model_device(model):