|
|
@@ -148,7 +148,7 @@ class Tester(object): |
|
|
|
self._predict_func = self._model.predict |
|
|
|
self._predict_func_wrapper = self._model.predict |
|
|
|
else: |
|
|
|
if _model_contains_inner_module(model): |
|
|
|
if _model_contains_inner_module(self._model): |
|
|
|
self._predict_func_wrapper = self._model.forward |
|
|
|
self._predict_func = self._model.module.forward |
|
|
|
else: |
|
|
|