|
|
@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device |
|
|
|
|
|
|
|
class Predictor(object): |
|
|
|
""" |
|
|
|
An interface for predicting outputs based on trained models. |
|
|
|
一个根据训练模型预测输出的预测器(Predictor) |
|
|
|
|
|
|
|
It does not care about evaluations of the model, which is different from Tester. |
|
|
|
This is a high-level model wrapper to be called by FastNLP. |
|
|
|
This class does not share any operations with Trainer and Tester. |
|
|
|
Currently, Predictor does not support GPU. |
|
|
|
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 |
|
|
|
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 |
|
|
|
|
|
|
|
:param torch.nn.Module network: 用来完成预测任务的模型 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, network): |
|
|
@@ -30,18 +30,19 @@ class Predictor(object): |
|
|
|
self.batch_size = 1 |
|
|
|
self.batch_output = [] |
|
|
|
|
|
|
|
def predict(self, data, seq_len_field_name=None): |
|
|
|
"""Perform inference using the trained model. |
|
|
|
def predict(self, data: DataSet, seq_len_field_name=None): |
|
|
|
"""用已经训练好的模型进行inference. |
|
|
|
|
|
|
|
:param data: a DataSet object. |
|
|
|
:param str seq_len_field_name: field name indicating sequence lengths |
|
|
|
:return: list of batch outputs |
|
|
|
:param fastNLP.DataSet data: 待预测的数据集 |
|
|
|
:param str seq_len_field_name: 表示序列长度信息的field名字 |
|
|
|
:return: dict dict里面的内容为模型预测的结果 |
|
|
|
""" |
|
|
|
if not isinstance(data, DataSet): |
|
|
|
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) |
|
|
|
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: |
|
|
|
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) |
|
|
|
|
|
|
|
prev_training = self.network.training |
|
|
|
self.network.eval() |
|
|
|
network_device = _get_model_device(self.network) |
|
|
|
batch_output = defaultdict(list) |
|
|
@@ -74,4 +75,5 @@ class Predictor(object): |
|
|
|
else: |
|
|
|
batch_output[key].append(value) |
|
|
|
|
|
|
|
self.network.train(prev_training) |
|
|
|
return batch_output |