|
|
@@ -105,18 +105,16 @@ class Tester(object): |
|
|
|
self.data_iterator = data |
|
|
|
else: |
|
|
|
raise TypeError("data type {} not support".format(type(data))) |
|
|
|
|
|
|
|
# 如果是DataParallel将没有办法使用predict方法 |
|
|
|
if isinstance(self._model, nn.DataParallel): |
|
|
|
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): |
|
|
|
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," |
|
|
|
" while DataParallel has no predict() function.") |
|
|
|
self._model = self._model.module |
|
|
|
|
|
|
|
|
|
|
|
# check predict |
|
|
|
if hasattr(self._model, 'predict') and callable(self._model.predict): |
|
|
|
self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids, |
|
|
|
self._model.output_device) |
|
|
|
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ |
|
|
|
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and |
|
|
|
callable(self._model.module.predict)): |
|
|
|
if isinstance(self._model, nn.DataParallel): |
|
|
|
self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, |
|
|
|
self._model.output_device) |
|
|
|
else: |
|
|
|
self._predict_func = self._model.predict |
|
|
|
else: |
|
|
|
if isinstance(model, nn.DataParallel): |
|
|
|
self._predict_func = self._model.module.forward |
|
|
|