Browse Source

Tester支持predict数据并行

tags/v0.4.10
yh 6 years ago
parent
commit
f68b2c5382
1 changed files with 9 additions and 11 deletions
  1. +9
    -11
      fastNLP/core/tester.py

+ 9
- 11
fastNLP/core/tester.py View File

@@ -105,18 +105,16 @@ class Tester(object):
self.data_iterator = data
else:
raise TypeError("data type {} not support".format(type(data)))
# 如果是DataParallel将没有办法使用predict方法
if isinstance(self._model, nn.DataParallel):
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function,"
" while DataParallel has no predict() function.")
self._model = self._model.module

# check predict
if hasattr(self._model, 'predict') and callable(self._model.predict):
self._predict_func = _data_parallel_wrapper(self._model.predict, self._model.device_ids,
self._model.output_device)
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
callable(self._model.module.predict)):
if isinstance(self._model, nn.DataParallel):
self._predict_func = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids,
self._model.output_device)
else:
self._predict_func = self._model.predict
else:
if isinstance(model, nn.DataParallel):
self._predict_func = self._model.module.forward


Loading…
Cancel
Save