class Predictor(object):
    """Inference-only wrapper around a trained model.

    Unlike a Tester, the predictor does not care about evaluation metrics; it
    only runs inference. It is a high-level model wrapper used by fastNLP and
    shares no operations with Trainer or Tester.

    :param torch.nn.Module network: the model used to make predictions
    """

    def __init__(self, network):
        """Store the model to run inference with.

        :param torch.nn.Module network: the model used to make predictions
        :raises ValueError: if ``network`` is not a ``torch.nn.Module``
        """
        if not isinstance(network, torch.nn.Module):
            raise ValueError(
                "Only fastNLP.models.BaseModel or torch.nn.Module is allowed, not {}".format(type(network)))
        self.network = network
        self.batch_size = 1
        self.batch_output = []

    def predict(self, data: "DataSet", seq_len_field_name=None):
        """Run inference with the already-trained model.

        The model's ``predict`` method is used when it has one, otherwise
        ``forward``. The model is switched to eval mode for the duration of
        the call and its previous training mode is restored afterwards, even
        if inference raises.

        :param fastNLP.DataSet data: the data set to predict on
        :param str seq_len_field_name: name of the field holding the sequence
            lengths; when given, outputs that are not per-instance scalars are
            truncated to each instance's length
        :return: a ``defaultdict(list)`` mapping every key of the model output
            to the collected predictions
        :raises ValueError: if ``data`` is not a DataSet or
            ``seq_len_field_name`` is not one of its fields
        """
        if not isinstance(data, DataSet):
            raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
        if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
            raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))

        prev_training = self.network.training
        self.network.eval()
        try:
            network_device = _get_model_device(self.network)
            batch_output = defaultdict(list)
            data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(),
                                        as_numpy=False)

            # Prefer a dedicated ``predict`` method over ``forward``.
            predict_func = getattr(self.network, "predict", self.network.forward)

            with torch.no_grad():
                for batch_x, batch_y in data_iterator:
                    _move_dict_value_to_device(batch_x, batch_y, device=network_device)
                    # Keep only the arguments that predict_func actually accepts.
                    refined_batch_x = _build_args(predict_func, **batch_x)
                    prediction = predict_func(**refined_batch_x)

                    seq_lens = None
                    if seq_len_field_name is not None:
                        seq_lens = batch_x[seq_len_field_name].tolist()

                    self._collect_prediction(batch_output, prediction, seq_lens)
        finally:
            # Restore the caller's train/eval mode even when inference fails.
            self.network.train(prev_training)
        return batch_output

    @staticmethod
    def _collect_prediction(batch_output, prediction, seq_lens=None):
        """Append one batch of model outputs to ``batch_output`` in place.

        :param batch_output: mapping from output key to a list of results
        :param dict prediction: model output for one batch, key -> tensor
        :param list seq_lens: per-instance lengths used for truncation, or
            None to keep multi-dimensional outputs untouched
        """
        for key, value in prediction.items():
            value = value.cpu().numpy()
            if value.ndim == 1 or (value.ndim == 2 and value.shape[1] == 1):
                # One value per instance: flatten the batch into the list.
                batch_output[key].extend(value.tolist())
            elif seq_lens is not None:
                # Cut every row down to the length of its instance.
                batch_output[key].extend(value[idx, :seq_len] for idx, seq_len in enumerate(seq_lens))
            else:
                batch_output[key].append(value)
def prepare_fake_dataset():
    """Build a toy two-class data set of 2-D Gaussian samples.

    1000 points are drawn around (-3, -3) and labelled ``[0.0]``, followed by
    1000 points around (3, 3) labelled ``[1.0]``; both use an identity
    covariance matrix.

    :return: a DataSet whose instances carry the fields ``x`` and ``y``
    """
    cov = np.array([[1, 0], [0, 1]])
    instances = []
    for center, label in ((-3, 0.0), (3, 1.0)):
        mean = np.array([center, center])
        samples = np.random.multivariate_normal(mean, cov, size=(1000,))
        instances.extend(
            Instance(x=[float(item[0]), float(item[1])], y=[label]) for item in samples)
    return DataSet(instances)


class LinearModel(torch.nn.Module):
    """Minimal model: a single linear layer mapping 2 features to 1 output."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 1)

    def forward(self, x):
        """Return the linear projection of ``x`` under the key ``"predict"``."""
        return {"predict": self.linear(x)}


class TestPredictor(unittest.TestCase):
    """Tests for Predictor."""

    def test_simple(self):
        """Predicting on a plain data set yields a list under ``"predict"``."""
        model = LinearModel()
        predictor = Predictor(model)
        data = prepare_fake_dataset()
        data.set_input("x")
        ans = predictor.predict(data)
        self.assertIsInstance(ans, defaultdict)
        self.assertIn("predict", ans)
        self.assertIsInstance(ans["predict"], list)
        # The model emits one scalar per instance, so the results are flattened
        # into exactly one entry per instance.
        self.assertEqual(len(ans["predict"]), len(data))

    @unittest.skip("sequence input/output test is not implemented yet")
    def test_sequence(self):
        """Placeholder for testing sequence input/output."""