|
|
@@ -111,15 +111,19 @@ class Tester(object): |
|
|
|
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and |
|
|
|
callable(self._model.module.predict)): |
|
|
|
if isinstance(self._model, nn.DataParallel): |
|
|
|
self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, |
|
|
|
self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, |
|
|
|
self._model.output_device) |
|
|
|
self._predict_func = self._model.module.predict |
|
|
|
else: |
|
|
|
self._predict_func = self._model.predict |
|
|
|
self._predict_func_wrapper = self._model.predict |
|
|
|
else: |
|
|
|
if isinstance(model, nn.DataParallel): |
|
|
|
if isinstance(self._model, nn.DataParallel): |
|
|
|
self._predict_func_wrapper = self._model.forward |
|
|
|
self._predict_func = self._model.module.forward |
|
|
|
else: |
|
|
|
self._predict_func = self._model.forward |
|
|
|
self._predict_func_wrapper = self._model.forward |
|
|
|
|
|
|
|
def test(self): |
|
|
|
"""开始进行验证,并返回验证结果。 |
|
|
@@ -176,7 +180,7 @@ class Tester(object): |
|
|
|
def _data_forward(self, func, x): |
|
|
|
"""A forward pass of the model. """ |
|
|
|
x = _build_args(func, **x) |
|
|
|
y = func(**x) |
|
|
|
y = self._predict_func_wrapper(**x) |
|
|
|
return y |
|
|
|
|
|
|
|
def _format_eval_results(self, results): |
|
|
|