From 8445bdbc793c69e998efd9381229820ae9a5ba9d Mon Sep 17 00:00:00 2001 From: ChenXin Date: Sun, 25 Aug 2019 16:57:47 +0800 Subject: [PATCH] delete predictor.py --- fastNLP/core/predictor.py | 79 ------------------------------------- test/core/test_predictor.py | 48 ---------------------- 2 files changed, 127 deletions(-) delete mode 100644 fastNLP/core/predictor.py delete mode 100644 test/core/test_predictor.py diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py deleted file mode 100644 index 2d6a7380..00000000 --- a/fastNLP/core/predictor.py +++ /dev/null @@ -1,79 +0,0 @@ -""" - ..todo:: - 检查这个类是否需要 -""" -from collections import defaultdict - -import torch - -from . import DataSetIter -from . import DataSet -from . import SequentialSampler -from .utils import _build_args, _move_dict_value_to_device, _get_model_device - - -class Predictor(object): - """ - 一个根据训练模型预测输出的预测器(Predictor) - - 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 - 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 - - :param torch.nn.Module network: 用来完成预测任务的模型 - """ - - def __init__(self, network): - if not isinstance(network, torch.nn.Module): - raise ValueError( - "Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) - self.network = network - self.batch_size = 1 - self.batch_output = [] - - def predict(self, data: DataSet, seq_len_field_name=None): - """用已经训练好的模型进行inference. - - :param fastNLP.DataSet data: 待预测的数据集 - :param str seq_len_field_name: 表示序列长度信息的field名字 - :return: dict dict里面的内容为模型预测的结果 - """ - if not isinstance(data, DataSet): - raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) - if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: - raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) - - prev_training = self.network.training - self.network.eval() - network_device = _get_model_device(self.network) - batch_output = defaultdict(list) - data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) - - if hasattr(self.network, "predict"): - predict_func = self.network.predict - else: - predict_func = self.network.forward - - with torch.no_grad(): - for batch_x, _ in data_iterator: - _move_dict_value_to_device(batch_x, _, device=network_device) - refined_batch_x = _build_args(predict_func, **batch_x) - prediction = predict_func(**refined_batch_x) - - if seq_len_field_name is not None: - seq_lens = batch_x[seq_len_field_name].tolist() - - for key, value in prediction.items(): - value = value.cpu().numpy() - if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): - batch_output[key].extend(value.tolist()) - else: - if seq_len_field_name is not None: - tmp_batch = [] - for idx, seq_len in enumerate(seq_lens): - tmp_batch.append(value[idx, :seq_len]) - batch_output[key].extend(tmp_batch) - else: - batch_output[key].append(value) - - self.network.train(prev_training) - return batch_output diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py deleted file mode 100644 index 701353dc..00000000 --- a/test/core/test_predictor.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest -from collections import defaultdict - -import numpy as np -import torch - -from fastNLP.core.dataset import DataSet -from fastNLP.core.instance import Instance -from fastNLP.core.predictor import Predictor - - -def prepare_fake_dataset(): - mean = np.array([-3, -3]) - cov = np.array([[1, 0], [0, 1]]) - class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) - - mean = np.array([3, 3]) - cov = np.array([[1, 0], [0, 1]]) - class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) - - data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + - [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) - return data_set - - -class LinearModel(torch.nn.Module): - def __init__(self): - super(LinearModel, self).__init__() - self.linear = torch.nn.Linear(2, 1) - - def forward(self, x): - return {"predict": self.linear(x)} - - -class TestPredictor(unittest.TestCase): - def test_simple(self): - model = LinearModel() - predictor = Predictor(model) - data = prepare_fake_dataset() - data.set_input("x") - ans = predictor.predict(data) - self.assertTrue(isinstance(ans, defaultdict)) - self.assertTrue("predict" in ans) - self.assertTrue(isinstance(ans["predict"], list)) - - def test_sequence(self): - # test sequence input/output - pass