diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 4cdd4ffb..6a0fdb9a 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -48,6 +48,7 @@ from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
+from .utils import _data_parallel_wrapper
 
 __all__ = [
     "Tester"
@@ -113,12 +114,9 @@ class Tester(object):
             self._model = self._model.module
 
         # check predict
-        if hasattr(self._model, 'predict'):
-            self._predict_func = self._model.predict
-            if not callable(self._predict_func):
-                _model_name = model.__class__.__name__
-                raise TypeError(f"`{_model_name}.predict` must be callable to be used "
-                                f"for evaluation, not `{type(self._predict_func)}`.")
+        if hasattr(self._model, 'predict') and callable(self._model.predict):
+            self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids,
+                                                        self._model.output_device)
         else:
             if isinstance(model, nn.DataParallel):
                 self._predict_func = self._model.module.forward
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index d26df966..8fe764f8 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -16,7 +16,9 @@ from collections import Counter, namedtuple
 import numpy as np
 import torch
 import torch.nn as nn
-
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
+from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.parallel_apply import parallel_apply
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required',
                                      'all_needed', 'varargs'])
@@ -277,6 +279,25 @@ def _move_model_to_device(model, device):
     model = model.to(device)
     return model
 
+def _data_parallel_wrapper(func, device_ids, output_device):
+    """
+    Wrap ``func`` so that it can be executed on multiple GPUs, mirroring
+    the forward pass of ``nn.DataParallel``.
+
+    :param func: the callable to parallelize, e.g. a model's ``predict`` method
+    :param device_ids: same as the ``device_ids`` of ``nn.DataParallel``
+    :param output_device: same as the ``output_device`` of ``nn.DataParallel``
+    :return: the wrapped callable
+    """
+    def wrapper(*inputs, **kwargs):
+        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
+        if len(device_ids) == 1:
+            return func(*inputs[0], **kwargs[0])
+        replicas = replicate(func, device_ids[:len(inputs)])
+        outputs = parallel_apply(replicas, inputs, kwargs)
+        return gather(outputs, output_device)
+    return wrapper
+
 def _get_model_device(model):
     """
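One point worth flagging for review: `torch.nn.parallel.replicate` expects an `nn.Module`, but the Tester hands `_data_parallel_wrapper` the bound method `self._model.predict`, so the multi-GPU branch (`len(device_ids) > 1`) would raise before `parallel_apply` runs. Below is a minimal sketch of a variant that replicates the module itself and resolves the target method on each replica; the string-valued `func_name` parameter, the `ToyModel` class, and the demo at the bottom are illustrative assumptions for this note, not code from the PR.

```python
import torch
import torch.nn as nn
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply


def _data_parallel_wrapper(func_name, device_ids, output_device):
    """Run ``network.<func_name>`` data-parallel across ``device_ids``.

    Taking the method *name* instead of a bound method lets us replicate
    the module itself and look the method up on each replica, so every
    copy runs with parameters living on its own device.
    """
    def wrapper(network, *inputs, **kwargs):
        # Split the batch along dim 0, one shard per device.
        inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
        if len(device_ids) == 1:
            return getattr(network, func_name)(*inputs[0], **kwargs[0])
        # Copy the module onto each device that received a shard ...
        replicas = replicate(network, device_ids[:len(inputs)])
        # ... and call each replica's method on its shard in parallel.
        outputs = parallel_apply([getattr(r, func_name) for r in replicas],
                                 inputs, kwargs)
        # Move the per-device outputs back onto output_device.
        return gather(outputs, output_device)
    return wrapper


if __name__ == '__main__' and torch.cuda.is_available():
    class ToyModel(nn.Module):  # hypothetical model, for this demo only
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 2)

        def predict(self, x):
            return {'pred': self.fc(x).argmax(dim=-1)}

    device_ids = list(range(torch.cuda.device_count()))
    model = ToyModel().to(device_ids[0])
    predict = _data_parallel_wrapper('predict', device_ids, device_ids[0])
    batch = torch.randn(16, 8, device=device_ids[0])
    print(predict(model, batch)['pred'].shape)  # torch.Size([16])
```

This is the same scatter → replicate → parallel_apply → gather pipeline that `nn.DataParallel.forward` uses, just aimed at `predict` instead of `forward`; replicating once per call costs a copy per batch but keeps each replica's parameters and buffers on the correct device.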