From f68b2c5382b6411d9aca8b674194b94896d296e9 Mon Sep 17 00:00:00 2001 From: yh Date: Mon, 1 Jul 2019 00:33:31 +0800 Subject: [PATCH] =?UTF-8?q?Tester=E6=94=AF=E6=8C=81predict=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=B9=B6=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/tester.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 6a0fdb9a..4fa31fd2 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -105,18 +105,16 @@ class Tester(object): self.data_iterator = data else: raise TypeError("data type {} not support".format(type(data))) - - # 如果是DataParallel将没有办法使用predict方法 - if isinstance(self._model, nn.DataParallel): - if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): - warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," - " while DataParallel has no predict() function.") - self._model = self._model.module - + # check predict - if hasattr(self._model, 'predict') and callable(self._model.predict): - self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids, - self._model.output_device) + if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ + (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and + callable(self._model.module.predict)): + if isinstance(self._model, nn.DataParallel): + self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, + self._model.output_device) + else: + self._predict_func = self._model.predict else: if isinstance(model, nn.DataParallel): self._predict_func = self._model.module.forward