diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 6a0fdb9a..4fa31fd2 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -105,18 +105,16 @@ class Tester(object): self.data_iterator = data else: raise TypeError("data type {} not support".format(type(data))) - - # 如果是DataParallel将没有办法使用predict方法 - if isinstance(self._model, nn.DataParallel): - if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): - warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," - " while DataParallel has no predict() function.") - self._model = self._model.module - + # check predict - if hasattr(self._model, 'predict') and callable(self._model.predict): - self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids, - self._model.output_device) + if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ + (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and + callable(self._model.module.predict)): + if isinstance(self._model, nn.DataParallel): + self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, + self._model.output_device) + else: + self._predict_func = self._model.predict else: if isinstance(model, nn.DataParallel): self._predict_func = self._model.module.forward