From baac29cfa09299f910d1a33ca39982bee79688a7 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Wed, 17 Oct 2018 10:59:11 +0800
Subject: [PATCH] fix tests

---
 fastNLP/core/instance.py         |  2 +-
 fastNLP/fastnlp.py               |  3 ++-
 fastNLP/loader/dataset_loader.py | 13 +++++++++++++
 test/core/test_predictor.py      |  1 +
 test/model/test_cws.py           |  9 ++++++---
 test/model/test_seq_label.py     | 20 +++++++++++++-------
 6 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
index b01c336b..a4eca1aa 100644
--- a/fastNLP/core/instance.py
+++ b/fastNLP/core/instance.py
@@ -20,7 +20,7 @@ class Instance(object):
         if old_name in self.indexes:
             self.indexes[new_name] = self.indexes.pop(old_name)
         else:
-            print("error, no such field: {}".format(old_name))
+            raise KeyError("error, no such field: {}".format(old_name))
         return self
 
     def set_target(self, **fields):
diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py
index 816db82d..92229d0d 100644
--- a/fastNLP/fastnlp.py
+++ b/fastNLP/fastnlp.py
@@ -182,7 +182,8 @@ class FastNLP(object):
         if self.infer_type in ["seq_label", "text_class"]:
             data_set = convert_seq_dataset(infer_input)
             data_set.index_field("word_seq", self.word_vocab)
-            data_set.set_origin_len("word_seq")
+            if self.infer_type == "seq_label":
+                data_set.set_origin_len("word_seq")
             return data_set
         else:
             raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index 4d3674e2..c9e76622 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -77,6 +77,19 @@ class DataSetLoader(BaseLoader):
     def load(self, path):
         raise NotImplementedError
 
+class RawDataSetLoader(DataSetLoader):
+    def __init__(self):
+        super(RawDataSetLoader, self).__init__()
+
+    def load(self, data_path, split=None):
+        with open(data_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        lines = lines if split is None else [l.split(split) for l in lines]
+        lines = list(filter(lambda x: len(x) > 0, lines))
+        return self.convert(lines)
+
+    def convert(self, data):
+        return convert_seq_dataset(data)
 
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py
index 2fb2c090..84275478 100644
--- a/test/core/test_predictor.py
+++ b/test/core/test_predictor.py
@@ -56,6 +56,7 @@ class TestPredictor(unittest.TestCase):
             self.assertTrue(res in class_vocab.word2idx)
 
         del model, predictor
+        infer_data_set.set_origin_len("word_seq")
 
         model = SeqLabeling(model_args)
         predictor = Predictor("./save/", pre.seq_label_post_processor)
diff --git a/test/model/test_cws.py b/test/model/test_cws.py
index aaadce2d..7f248dce 100644
--- a/test/model/test_cws.py
+++ b/test/model/test_cws.py
@@ -8,7 +8,7 @@ from fastNLP.core.preprocess import save_pickle, load_pickle
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
-from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader
+from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.models.sequence_modeling import SeqLabeling
 from fastNLP.saver.model_saver import ModelSaver
@@ -38,9 +38,9 @@ def infer():
     print("model loaded!")
 
     # Load infer data
-    infer_data = TokenizeDataSetLoader().load(data_infer_path)
+    infer_data = RawDataSetLoader().load(data_infer_path)
     infer_data.index_field("word_seq", word2index)
-
+    infer_data.set_origin_len("word_seq")
     # inference
     infer = SeqLabelInfer(pickle_path)
     results = infer.predict(model, infer_data)
@@ -57,6 +57,9 @@ def train_test():
     word_vocab = Vocabulary()
     label_vocab = Vocabulary()
     data_train.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
+    data_train.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
+    data_train.set_origin_len("word_seq")
+    data_train.rename_field("label_seq", "truth").set_target(truth=False)
     train_args["vocab_size"] = len(word_vocab)
     train_args["num_classes"] = len(label_vocab)
 
diff --git a/test/model/test_seq_label.py b/test/model/test_seq_label.py
index ba62b25b..09d43008 100644
--- a/test/model/test_seq_label.py
+++ b/test/model/test_seq_label.py
@@ -1,6 +1,7 @@
 import os
 
-from fastNLP.core.dataset import DataSet
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.dataset_loader import TokenizeDataSetLoader
 from fastNLP.core.metrics import SeqLabelEvaluator
 from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.preprocess import save_pickle
@@ -25,14 +26,19 @@ def test_training():
     ConfigLoader().load_config(config_dir, {
         "test_seq_label_trainer": trainer_args,
         "test_seq_label_model": model_args})
-    data_set = DataSet()
-    word_vocab = V
+    data_set = TokenizeDataSetLoader().load(data_path)
+    word_vocab = Vocabulary()
+    label_vocab = Vocabulary()
+    data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
+    data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
+    data_set.set_origin_len("word_seq")
+    data_set.rename_field("label_seq", "truth").set_target(truth=False)
     data_train, data_dev = data_set.split(0.3, shuffle=True)
-    model_args["vocab_size"] = len(data_set.word_vocab)
-    model_args["num_classes"] = len(data_set.label_vocab)
+    model_args["vocab_size"] = len(word_vocab)
+    model_args["num_classes"] = len(label_vocab)
 
-    save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl")
-    save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl")
+    save_pickle(word_vocab, pickle_path, "word2id.pkl")
+    save_pickle(label_vocab, pickle_path, "label2id.pkl")
 
     trainer = SeqLabelTrainer(
         epochs=trainer_args["epochs"],
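
Usage note (not part of the patch): the sketch below condenses the data-preparation
chain that the updated tests now follow, so the intended call order is visible in one
place. Every call is taken from the added lines above; `data_path` is a placeholder
for whatever corpus path the tests read from their config and is not defined in this
diff.

    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.loader.dataset_loader import TokenizeDataSetLoader

    data_path = "path/to/tokenized_corpus.txt"  # placeholder, not defined in this diff

    # Load token/label pairs and build vocabularies over the raw fields.
    data_set = TokenizeDataSetLoader().load(data_path)
    word_vocab = Vocabulary()
    label_vocab = Vocabulary()
    data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab)

    # Index both fields, record the original sequence lengths, then rename the
    # labels to "truth" and mark them as targets, as the updated tests do.
    data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab)
    data_set.set_origin_len("word_seq")
    data_set.rename_field("label_seq", "truth").set_target(truth=False)

    data_train, data_dev = data_set.split(0.3, shuffle=True)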