Browse Source

update documents in predictor

tags/v0.4.10
xuyige 6 years ago
parent
commit
8e7a604b29
1 changed files with 12 additions and 10 deletions
  1. +12
    -10
      fastNLP/core/predictor.py

+ 12
- 10
fastNLP/core/predictor.py View File

@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device

class Predictor(object):
"""
An interface for predicting outputs based on trained models.
一个根据训练模型预测输出的预测器(Predictor)

It does not care about evaluations of the model, which is different from Tester.
This is a high-level model wrapper to be called by FastNLP.
This class does not share any operations with Trainer and Tester.
Currently, Predictor does not support GPU.
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。
:param torch.nn.Module network: 用来完成预测任务的模型
"""

def __init__(self, network):
@@ -30,18 +30,19 @@ class Predictor(object):
self.batch_size = 1
self.batch_output = []

def predict(self, data, seq_len_field_name=None):
"""Perform inference using the trained model.
def predict(self, data: DataSet, seq_len_field_name=None):
"""用已经训练好的模型进行inference.

:param data: a DataSet object.
:param str seq_len_field_name: field name indicating sequence lengths
:return: list of batch outputs
:param fastNLP.DataSet data: 待预测的数据集
:param str seq_len_field_name: 表示序列长度信息的field名字
:return: dict dict里面的内容为模型预测的结果
"""
if not isinstance(data, DataSet):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

prev_training = self.network.training
self.network.eval()
network_device = _get_model_device(self.network)
batch_output = defaultdict(list)
@@ -74,4 +75,5 @@ class Predictor(object):
else:
batch_output[key].append(value)

self.network.train(prev_training)
return batch_output

Loading…
Cancel
Save