Browse Source

支持predict数据并行 (support data-parallel predict)

tags/v0.4.10
yh 6 years ago
parent
commit
5f19601d20
2 changed files with 26 additions and 7 deletions
  1. +4
    -6
      fastNLP/core/tester.py
  2. +22
    -1
      fastNLP/core/utils.py

+ 4
- 6
fastNLP/core/tester.py View File

@@ -48,6 +48,7 @@ from .utils import _move_dict_value_to_device
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import _get_model_device from .utils import _get_model_device
from .utils import _move_model_to_device from .utils import _move_model_to_device
from .utils import _data_parallel_wrapper


__all__ = [ __all__ = [
"Tester" "Tester"
@@ -113,12 +114,9 @@ class Tester(object):
self._model = self._model.module self._model = self._model.module
# check predict # check predict
if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict
if not callable(self._predict_func):
_model_name = model.__class__.__name__
raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.")
if hasattr(self._model, 'predict') and callable(self._model.predict):
self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids,
self._model.output_device)
else: else:
if isinstance(model, nn.DataParallel): if isinstance(model, nn.DataParallel):
self._predict_func = self._model.module.forward self._predict_func = self._model.module.forward


+ 22
- 1
fastNLP/core/utils.py View File

@@ -16,7 +16,9 @@ from collections import Counter, namedtuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn

from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs']) 'varargs'])
@@ -277,6 +279,25 @@ def _move_model_to_device(model, device):
model = model.to(device) model = model.to(device)
return model return model


def _data_parallel_wrapper(func, device_ids, output_device):
"""
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数

:param func: callable
:param device_ids: nn.DataParallel中的device_ids
:param inputs:
:param kwargs:
:return:
"""
def wrapper(*inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1:
return func(*inputs[0], **kwargs[0])
replicas = replicate(func, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, inputs, kwargs)
return gather(outputs, output_device)
return wrapper



def _get_model_device(model): def _get_model_device(model):
""" """


Loading…
Cancel
Save