
Parallelize the Tester's predict function with DataParallel

tags/v0.4.10
yh committed 5 years ago
commit b0fe264e42
3 changed files with 94 additions and 26 deletions:

  1. fastNLP/core/_parallel_utils.py (+88 −0)
  2. fastNLP/core/tester.py (+6 −5)
  3. fastNLP/core/utils.py (+0 −21)

fastNLP/core/_parallel_utils.py (+88 −0)

@@ -0,0 +1,88 @@

import threading
import torch
from torch.nn.parallel.parallel_apply import get_a_var

from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate


def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
r"""Applies each `module` in :attr:`modules` in parallel on arguments
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
on each of :attr:`devices`.

:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
:attr:`devices` (if given) should all have same length. Moreover, each
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(inputs)
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
kwargs_tup = ({},) * len(modules)
if devices is not None:
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)

lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()

def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
if device is None:
device = get_a_var(input).get_device()
try:
with torch.cuda.device(device):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
output = getattr(module, func_name)(*input, **kwargs)
with lock:
results[i] = output
except Exception as e:
with lock:
results[i] = e

if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))]

for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

outputs = []
for i in range(len(inputs)):
output = results[i]
if isinstance(output, Exception):
raise output
outputs.append(output)
return outputs


def _data_parallel_wrapper(func_name, device_ids, output_device):
"""
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数

:param str, func_name: 对network中的这个函数进行多卡运行
:param device_ids: nn.DataParallel中的device_ids
:param output_device: nn.DataParallel中的output_device
:return:
"""
def wrapper(network, *inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1:
return getattr(network, func_name)(*inputs[0], **kwargs[0])
replicas = replicate(network, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
return gather(outputs, output_device)
return wrapper
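
A minimal sketch of how the new wrapper is meant to be used (not part of the commit; `ToyPredictor` and the two-GPU setup are illustrative assumptions):

import torch
import torch.nn as nn

from fastNLP.core._parallel_utils import _data_parallel_wrapper

class ToyPredictor(nn.Module):
    # hypothetical model exposing a `predict` method, as fastNLP models do
    def __init__(self):
        super(ToyPredictor, self).__init__()
        self.fc = nn.Linear(10, 2)

    def predict(self, x):
        # inference-only path; returning a tensor keeps gather() simple
        return self.fc(x).argmax(dim=-1)

if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    model = ToyPredictor().to('cuda:0')
    predict = _data_parallel_wrapper('predict', device_ids=[0, 1], output_device=0)
    x = torch.randn(8, 10, device='cuda:0')
    # x is scattered along dim 0 over GPUs 0 and 1, predict() runs on each
    # replica in its own thread, and the results are gathered onto GPU 0
    pred = predict(model, x)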

fastNLP/core/tester.py (+6 −5)

@@ -32,8 +32,6 @@ Tester calls model.eval() before verification starts, to signal that it has entered evaluation
 
 
 """
-import warnings
-
 import torch
 import torch.nn as nn
 
@@ -48,7 +46,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from .utils import _data_parallel_wrapper
+from ._parallel_utils import _data_parallel_wrapper
+from functools import partial
 
 __all__ = [
     "Tester"
@@ -111,8 +110,10 @@ class Tester(object):
             (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
              callable(self._model.module.predict)):
             if isinstance(self._model, nn.DataParallel):
-                self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids,
-                                                                    self._model.output_device)
+                self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
+                                                                            self._model.device_ids,
+                                                                            self._model.output_device),
+                                                     network=self._model.module)
                 self._predict_func = self._model.module.predict
             else:
                 self._predict_func = self._model.predict
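
A note on the `partial` above: it binds `network=self._model.module` by keyword, so the wrapper must later be called with keyword arguments only (a positional argument would collide with the bound `network`), which matches how the Tester passes batch fields. In isolation, a sketch of what gets built (names taken from the diff; the commented call is illustrative):

from functools import partial

from fastNLP.core._parallel_utils import _data_parallel_wrapper

def build_predict_wrapper(model):
    # `model` is an nn.DataParallel instance whose .module has a predict()
    return partial(
        _data_parallel_wrapper('predict', model.device_ids, model.output_device),
        network=model.module)

# later, inside the evaluation loop, roughly:
#   pred_dict = predict_func_wrapper(words=batch_x['words'], seq_len=batch_x['seq_len'])
# the keyword arguments are scattered across model.device_ids and the
# outputs gathered back on model.output_device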


fastNLP/core/utils.py (+0 −21)

@@ -16,9 +16,6 @@ from collections import Counter, namedtuple
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
-from torch.nn.parallel.replicate import replicate
-from torch.nn.parallel.parallel_apply import parallel_apply
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -279,24 +276,6 @@ def _move_model_to_device(model, device):
     model = model.to(device)
     return model
 
-def _data_parallel_wrapper(func, device_ids, output_device):
-    """
-    Wrapper for running a function on multiple GPUs; modeled on nn.DataParallel's forward function.
-
-    :param func: callable
-    :param device_ids: device_ids, as in nn.DataParallel
-    :param inputs:
-    :param kwargs:
-    :return:
-    """
-    def wrapper(*inputs, **kwargs):
-        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
-        if len(device_ids) == 1:
-            return func(*inputs[0], **kwargs[0])
-        replicas = replicate(func, device_ids[:len(inputs)])
-        outputs = parallel_apply(replicas, inputs, kwargs)
-        return gather(outputs, output_device)
-    return wrapper
 
 
 def _get_model_device(model):
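
Why the old helper above had to go: `replicate()` expects an `nn.Module` but here received the bound `predict` callable, and PyTorch's stock `parallel_apply` always invokes `module(...)`, i.e. `forward`, so a custom `predict` method could never be dispatched. The new `parallel_apply` in `_parallel_utils.py` instead resolves the target method by name on each replica; the core of that dispatch, in a CPU-runnable miniature (toy module, illustrative only):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x):
        return x + 1

    def predict(self, x):
        return x * 2

m = Toy()
func_name = 'predict'
# the same lookup the new parallel_apply performs on every replica:
out = getattr(m, func_name)(torch.ones(3))   # tensor([2., 2., 2.])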

