import threading

import torch
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather

def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
    r"""Calls the method named :attr:`func_name` of each `module` in
    :attr:`modules` in parallel on the arguments contained in :attr:`inputs`
    (positional) and :attr:`kwargs_tup` (keyword), on each of :attr:`devices`.

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object passed as the only
    argument to a module, or a collection of positional arguments.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        # Propagate the caller's grad mode into this worker thread.
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = getattr(module, func_name)(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # Store the exception so the main thread can re-raise it in order.
            with lock:
                results[i] = e

    if len(modules) > 1:
        # One thread per module replica; each runs func_name on its own device.
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        # Re-raise any exception captured in a worker thread, preserving order.
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs
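
# A minimal usage sketch (an illustration, not part of this module): with two
# visible CUDA devices, scatter a batch with ``scatter_kwargs``, copy a toy
# ``nn.Linear`` onto each device with ``replicate``, and let ``parallel_apply``
# call ``forward`` on each replica in its own thread. The model, devices, and
# shapes here are assumptions for the example:
#
#     model = torch.nn.Linear(16, 4).cuda(0)
#     scattered, _ = scatter_kwargs((torch.randn(8, 16, device='cuda:0'),),
#                                   {}, [0, 1], dim=0)
#     replicas = replicate(model, [0, 1][:len(scattered)])
#     outputs = parallel_apply(replicas, 'forward', scattered,
#                              devices=[0, 1][:len(replicas)])
#     merged = gather(outputs, 0)  # concatenated back on device 0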


def _data_parallel_wrapper(func_name, device_ids, output_device):
    """
    Wraps a method of ``network`` so that it runs on multiple GPUs, following
    the logic of ``nn.DataParallel``'s ``forward`` function.

    :param str func_name: name of the method on ``network`` to run on multiple GPUs
    :param device_ids: same as ``device_ids`` in ``nn.DataParallel``
    :param output_device: same as ``output_device`` in ``nn.DataParallel``
    :return: the wrapped function
    """
    def wrapper(network, *inputs, **kwargs):
        # Split positional and keyword arguments along dim 0, one chunk per device.
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
            return getattr(network, func_name)(*inputs[0], **kwargs[0])
        # Copy the network onto each device, run func_name on every replica in
        # parallel, then gather the outputs onto output_device.
        replicas = replicate(network, device_ids[:len(inputs)])
        outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
        return gather(outputs, output_device)
    return wrapper
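
# A hedged example of wrapping a model method (``MyModel`` and its ``predict``
# method are assumptions for illustration; any ``nn.Module`` method whose
# tensor arguments can be scattered along dim 0 would work the same way):
#
#     model = MyModel().cuda(0)
#     parallel_predict = _data_parallel_wrapper('predict', [0, 1], 0)
#     preds = parallel_predict(model, torch.randn(8, 16, device='cuda:0'))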