diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index ce016bb6..2d6a7380 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device class Predictor(object): """ - An interface for predicting outputs based on trained models. + 一个根据训练模型预测输出的预测器(Predictor) - It does not care about evaluations of the model, which is different from Tester. - This is a high-level model wrapper to be called by FastNLP. - This class does not share any operations with Trainer and Tester. - Currently, Predictor does not support GPU. + 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 + 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 + + :param torch.nn.Module network: 用来完成预测任务的模型 """ def __init__(self, network): @@ -30,18 +30,19 @@ class Predictor(object): self.batch_size = 1 self.batch_output = [] - def predict(self, data, seq_len_field_name=None): - """Perform inference using the trained model. + def predict(self, data: DataSet, seq_len_field_name=None): + """用已经训练好的模型进行inference. - :param data: a DataSet object. - :param str seq_len_field_name: field name indicating sequence lengths - :return: list of batch outputs + :param fastNLP.DataSet data: 待预测的数据集 + :param str seq_len_field_name: 表示序列长度信息的field名字 + :return: dict dict里面的内容为模型预测的结果 """ if not isinstance(data, DataSet): raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) + prev_training = self.network.training self.network.eval() network_device = _get_model_device(self.network) batch_output = defaultdict(list) @@ -74,4 +75,5 @@ class Predictor(object): else: batch_output[key].append(value) + self.network.train(prev_training) return batch_output