Browse Source

Tester数据并行

tags/v0.4.10
yh 5 years ago
parent
commit
3c984872d3
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      fastNLP/core/tester.py

+ 7
- 3
fastNLP/core/tester.py View File

@@ -111,15 +111,19 @@ class Tester(object):
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
callable(self._model.module.predict)):
if isinstance(self._model, nn.DataParallel):
self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids,
self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids,
self._model.output_device)
self._predict_func = self._model.module.predict
else:
self._predict_func = self._model.predict
self._predict_func_wrapper = self._model.predict
else:
if isinstance(model, nn.DataParallel):
if isinstance(self._model, nn.DataParallel):
self._predict_func_wrapper = self._model.forward
self._predict_func = self._model.module.forward
else:
self._predict_func = self._model.forward
self._predict_func_wrapper = self._model.forward
def test(self):
"""开始进行验证,并返回验证结果。
@@ -176,7 +180,7 @@ class Tester(object):
def _data_forward(self, func, x):
"""A forward pass of the model. """
x = _build_args(func, **x)
y = func(**x)
y = self._predict_func_wrapper(**x)
return y
def _format_eval_results(self, results):


Loading…
Cancel
Save