|
|
@@ -3,10 +3,10 @@ from collections import defaultdict |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
from ..core.batch import Batch |
|
|
|
from ..core.dataset import DataSet |
|
|
|
from ..core.sampler import SequentialSampler |
|
|
|
from ..core.vocabulary import Vocabulary |
|
|
|
from fastNLP.core.batch import Batch |
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.core.sampler import SequentialSampler |
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
|
|
|
|
|
|
|
|
class Processor(object): |
|
|
@@ -232,7 +232,7 @@ class SeqLenProcessor(Processor): |
|
|
|
return dataset |
|
|
|
|
|
|
|
|
|
|
|
from ..core.utils import _build_args |
|
|
|
from fastNLP.core.utils import _build_args |
|
|
|
|
|
|
|
|
|
|
|
class ModelProcessor(Processor): |
|
|
@@ -257,10 +257,7 @@ class ModelProcessor(Processor): |
|
|
|
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) |
|
|
|
|
|
|
|
batch_output = defaultdict(list) |
|
|
|
if hasattr(self.model, "predict"): |
|
|
|
predict_func = self.model.predict |
|
|
|
else: |
|
|
|
predict_func = self.model.forward |
|
|
|
predict_func = self.model.forward |
|
|
|
with torch.no_grad(): |
|
|
|
for batch_x, _ in data_iterator: |
|
|
|
refined_batch_x = _build_args(predict_func, **batch_x) |