|
|
@@ -0,0 +1,79 @@ |
|
|
|
""" |
|
|
|
..todo:: |
|
|
|
检查这个类是否需要 |
|
|
|
""" |
|
|
|
from collections import defaultdict |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
from . import DataSetIter |
|
|
|
from . import DataSet |
|
|
|
from . import SequentialSampler |
|
|
|
from .utils import _build_args, _move_dict_value_to_device, _get_model_device |
|
|
|
|
|
|
|
|
|
|
|
class Predictor(object): |
|
|
|
""" |
|
|
|
一个根据训练模型预测输出的预测器(Predictor) |
|
|
|
|
|
|
|
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 |
|
|
|
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 |
|
|
|
|
|
|
|
:param torch.nn.Module network: 用来完成预测任务的模型 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network): |
|
|
|
if not isinstance(network, torch.nn.Module): |
|
|
|
raise ValueError( |
|
|
|
"Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) |
|
|
|
self.network = network |
|
|
|
self.batch_size = 1 |
|
|
|
self.batch_output = [] |
|
|
|
|
|
|
|
def predict(self, data: DataSet, seq_len_field_name=None): |
|
|
|
"""用已经训练好的模型进行inference. |
|
|
|
|
|
|
|
:param fastNLP.DataSet data: 待预测的数据集 |
|
|
|
:param str seq_len_field_name: 表示序列长度信息的field名字 |
|
|
|
:return: dict dict里面的内容为模型预测的结果 |
|
|
|
""" |
|
|
|
if not isinstance(data, DataSet): |
|
|
|
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) |
|
|
|
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: |
|
|
|
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) |
|
|
|
|
|
|
|
prev_training = self.network.training |
|
|
|
self.network.eval() |
|
|
|
network_device = _get_model_device(self.network) |
|
|
|
batch_output = defaultdict(list) |
|
|
|
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) |
|
|
|
|
|
|
|
if hasattr(self.network, "predict"): |
|
|
|
predict_func = self.network.predict |
|
|
|
else: |
|
|
|
predict_func = self.network.forward |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
for batch_x, _ in data_iterator: |
|
|
|
_move_dict_value_to_device(batch_x, _, device=network_device) |
|
|
|
refined_batch_x = _build_args(predict_func, **batch_x) |
|
|
|
prediction = predict_func(**refined_batch_x) |
|
|
|
|
|
|
|
if seq_len_field_name is not None: |
|
|
|
seq_lens = batch_x[seq_len_field_name].tolist() |
|
|
|
|
|
|
|
for key, value in prediction.items(): |
|
|
|
value = value.cpu().numpy() |
|
|
|
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): |
|
|
|
batch_output[key].extend(value.tolist()) |
|
|
|
else: |
|
|
|
if seq_len_field_name is not None: |
|
|
|
tmp_batch = [] |
|
|
|
for idx, seq_len in enumerate(seq_lens): |
|
|
|
tmp_batch.append(value[idx, :seq_len]) |
|
|
|
batch_output[key].extend(tmp_batch) |
|
|
|
else: |
|
|
|
batch_output[key].append(value) |
|
|
|
|
|
|
|
self.network.train(prev_training) |
|
|
|
return batch_output |