From fac830e1cd4bdad4fa7146e63efb97cfdeaeec1a Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Mon, 20 Aug 2018 19:25:19 +0800
Subject: [PATCH] fix bugs and clean up

---
 fastNLP/core/preprocess.py |  8 ++++++--
 test/data_for_tests/config |  4 ++--
 test/seq_labeling.py       | 19 +++++--------------
 test/text_classify.py      |  7 ++++---
 4 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py
index 6b81bff1..dfaf3e94 100644
--- a/fastNLP/core/preprocess.py
+++ b/fastNLP/core/preprocess.py
@@ -134,7 +134,10 @@ class BasePreprocess(object):
             results.append(data_dev)
         if test_data:
             results.append(data_test)
-        return tuple(results)
+        if len(results) == 1:
+            return results[0]
+        else:
+            return tuple(results)

     def build_dict(self, data):
         raise NotImplementedError
@@ -282,7 +285,8 @@ class ClassPreprocess(BasePreprocess):
         data_index = []
         for example in data:
             word_list = []
-            for word, label in zip(example[0]):
+            # example[0] is the word list, example[1] is the single label
+            for word in example[0]:
                 word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
             label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
             data_index.append([word_list, label_index])
diff --git a/test/data_for_tests/config b/test/data_for_tests/config
index 60a7c9a5..2ffdcf3b 100644
--- a/test/data_for_tests/config
+++ b/test/data_for_tests/config
@@ -95,10 +95,10 @@ num_classes = 27
 [text_class]
 epochs = 1
 batch_size = 10
-pickle_path = "./data_for_tests/"
+pickle_path = "./save_path/"
 validate = false
 save_best_dev = false
-model_saved_path = "./data_for_tests/"
+model_saved_path = "./save_path/"
 use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
diff --git a/test/seq_labeling.py b/test/seq_labeling.py
index fe67b79c..b4007092 100644
--- a/test/seq_labeling.py
+++ b/test/seq_labeling.py
@@ -14,7 +14,7 @@ from fastNLP.core.predictor import SeqLabelInfer

 data_name = "people.txt"
 data_path = "data_for_tests/people.txt"
-pickle_path = "data_for_tests"
+pickle_path = "seq_label/"
 data_infer_path = "data_for_tests/people_infer.txt"


@@ -33,21 +33,12 @@ def infer():
     model = SeqLabeling(test_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")

     # Data Loader
     raw_data_loader = BaseLoader(data_name, data_infer_path)
     infer_data = raw_data_loader.load_lines()
-    """
-    Transform strings into list of list of strings.
-    [
-        [word_11, word_12, ...],
-        [word_21, word_22, ...],
-        ...
-    ]
-    In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
-    """

     # Inference interface
     infer = SeqLabelInfer(pickle_path)
@@ -69,7 +60,7 @@ def train_and_test():

     # Preprocessor
     p = SeqLabelPreprocess()
-    data_train, data_dev = p.run(train_data, pickle_path, train_dev_split=0.5)
+    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes

@@ -84,7 +75,7 @@ def train_and_test():
     print("Training finished!")

     # Saver
-    saver = ModelSaver("./data_for_tests/saved_model.pkl")
+    saver = ModelSaver(pickle_path + "saved_model.pkl")
     saver.save_pytorch(model)
     print("Model saved!")

@@ -94,7 +85,7 @@ def train_and_test():
     model = SeqLabeling(train_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")

     # Load test configuration
diff --git a/test/text_classify.py b/test/text_classify.py
index d6a77781..c452e86c 100644
--- a/test/text_classify.py
+++ b/test/text_classify.py
@@ -14,6 +14,7 @@ from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.models.cnn_text_classification import CNNText
 from fastNLP.saver.model_saver import ModelSaver

+save_path = "./test_classification/"
 data_dir = "./data_for_tests/"
 train_file = 'text_classify.txt'
 model_name = "model_class.pkl"
@@ -27,8 +28,8 @@ def infer():
     unlabeled_data = [x[0] for x in data]

     # pre-process data
-    pre = ClassPreprocess(data_dir)
-    vocab_size, n_classes = pre.process(data, "data_train.pkl")
+    pre = ClassPreprocess()
+    vocab_size, n_classes = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", vocab_size)
     print("number of classes:", n_classes)

@@ -60,7 +61,7 @@ def train():

     # pre-process data
     pre = ClassPreprocess()
-    data_train = pre.run(data, pickle_path=data_dir)
+    data_train = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", pre.vocab_size)
     print("number of classes:", pre.num_classes)