diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py
new file mode 100644
index 00000000..4a7757d3
--- /dev/null
+++ b/fastNLP/core/_parallel_utils.py
@@ -0,0 +1,88 @@
+
+import threading
+import torch
+from torch.nn.parallel.parallel_apply import get_a_var
+
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
+from torch.nn.parallel.replicate import replicate
+
+
+def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
+    r"""Applies each `module` in :attr:`modules` in parallel on arguments
+    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+    on each of :attr:`devices`.
+
+    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+    :attr:`devices` (if given) should all have the same length. Moreover, each
+    element of :attr:`inputs` can either be a single object as the only argument
+    to a module, or a collection of positional arguments.
+    """
+    assert len(modules) == len(inputs)
+    if kwargs_tup is not None:
+        assert len(modules) == len(kwargs_tup)
+    else:
+        kwargs_tup = ({},) * len(modules)
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)
+
+    lock = threading.Lock()
+    results = {}
+    grad_enabled = torch.is_grad_enabled()
+
+    def _worker(i, module, input, kwargs, device=None):
+        torch.set_grad_enabled(grad_enabled)
+        if device is None:
+            device = get_a_var(input).get_device()
+        try:
+            with torch.cuda.device(device):
+                # this also avoids accidental slicing of `input` if it is a Tensor
+                if not isinstance(input, (list, tuple)):
+                    input = (input,)
+                output = getattr(module, func_name)(*input, **kwargs)
+            with lock:
+                results[i] = output
+        except Exception as e:
+            with lock:
+                results[i] = e
+
+    if len(modules) > 1:
+        threads = [threading.Thread(target=_worker,
+                                    args=(i, module, input, kwargs, device))
+                   for i, (module, input, kwargs, device) in
+                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
+
+    outputs = []
+    for i in range(len(inputs)):
+        output = results[i]
+        if isinstance(output, Exception):
+            raise output
+        outputs.append(output)
+    return outputs
+
+
+def _data_parallel_wrapper(func_name, device_ids, output_device):
+    """
+    Wrapper for running a method of a network on multiple GPUs; it follows the logic of nn.DataParallel's forward.
+
+    :param str, func_name: name of the method on the network that should be run on multiple GPUs
+    :param device_ids: device_ids, as in nn.DataParallel
+    :param output_device: output_device, as in nn.DataParallel
+    :return:
+    """
+    def wrapper(network, *inputs, **kwargs):
+        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
+        if len(device_ids) == 1:
+            return getattr(network, func_name)(*inputs[0], **kwargs[0])
+        replicas = replicate(network, device_ids[:len(inputs)])
+        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
+        return gather(outputs, output_device)
+    return wrapper
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 68950c10..7048d0ae 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -32,8 +32,6 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation
 
 """
 
-import warnings
-
 import torch
 import torch.nn as nn
 
@@ -48,7 +46,8 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
-from .utils import _data_parallel_wrapper
+from ._parallel_utils import _data_parallel_wrapper
+from functools import partial
 
 __all__ = [
     "Tester"
@@ -111,8 +110,10 @@ class Tester(object):
                 (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
                  callable(self._model.module.predict)):
             if isinstance(self._model, nn.DataParallel):
-                self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids,
-                                                                    self._model.output_device)
+                self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
+                                                                            self._model.device_ids,
+                                                                            self._model.output_device),
+                                                     network=self._model.module)
                 self._predict_func = self._model.module.predict
             else:
                 self._predict_func = self._model.predict
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 8fe764f8..490f9f8f 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -16,9 +16,6 @@ from collections import Counter, namedtuple
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
-from torch.nn.parallel.replicate import replicate
-from torch.nn.parallel.parallel_apply import parallel_apply
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -279,24 +276,6 @@ def _move_model_to_device(model, device):
         model = model.to(device)
     return model
 
-def _data_parallel_wrapper(func, device_ids, output_device):
-    """
-    这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数
-
-    :param func: callable
-    :param device_ids: nn.DataParallel中的device_ids
-    :param inputs:
-    :param kwargs:
-    :return:
-    """
-    def wrapper(*inputs, **kwargs):
-        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
-        if len(device_ids) == 1:
-            return func(*inputs[0], **kwargs[0])
-        replicas = replicate(func, device_ids[:len(inputs)])
-        outputs = parallel_apply(replicas, inputs, kwargs)
-        return gather(outputs, output_device)
-    return wrapper
 
 
 def _get_model_device(model):
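
Note (editor's sketch, not part of the patch): the refactored _data_parallel_wrapper takes a method name plus the DataParallel device settings and returns a wrapper(network, *inputs, **kwargs) callable; the Tester change above binds it to the wrapped module with functools.partial. A minimal usage sketch under that assumption follows; the helper name build_predict_wrapper is illustrative only, and it assumes `model` is an nn.DataParallel instance whose underlying module defines a `predict` method.

    from functools import partial

    import torch.nn as nn

    from fastNLP.core._parallel_utils import _data_parallel_wrapper


    def build_predict_wrapper(model: nn.DataParallel):
        # Pre-bind network=model.module, so the returned callable is invoked with
        # the batch's keyword arguments only (as Tester does). The wrapper scatters
        # those arguments across model.device_ids, calls `predict` on each replica,
        # and gathers the outputs on model.output_device.
        return partial(_data_parallel_wrapper('predict',
                                               model.device_ids,
                                               model.output_device),
                       network=model.module)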