diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index b22f32ef..e45f1017 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -64,7 +64,7 @@ class BaseTester(Action): :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). """ if self.save_dev_data is None: - data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) + data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb")) self.save_dev_data = data_dev return self.save_dev_data diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py index fd378ba0..2c972ddd 100644 --- a/fastNLP/loader/preprocess.py +++ b/fastNLP/loader/preprocess.py @@ -95,8 +95,11 @@ class POSPreprocess(BasePreprocess): if not pickle_exist(pickle_path, "data_train.pkl"): data_train = self.to_index(data) if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): - data_dev = data_train[: int(len(data_train) * train_dev_split)] + split = int(len(data_train) * train_dev_split) + data_dev = data_train[: split] + data_train = data_train[split:] save_pickle(data_dev, self.pickle_path, "data_dev.pkl") + print("{} of the training data is split for validation. ".format(train_dev_split)) save_pickle(data_train, self.pickle_path, "data_train.pkl") def build_dict(self, data): diff --git a/reproduction/chinese_word_seg/cws.cfg b/reproduction/chinese_word_seg/cws.cfg index ded4f623..cdcb4496 100644 --- a/reproduction/chinese_word_seg/cws.cfg +++ b/reproduction/chinese_word_seg/cws.cfg @@ -1,5 +1,5 @@ [train] -epochs = 2 +epochs = 10 batch_size = 32 pickle_path = "./save/" validate = true diff --git a/reproduction/chinese_word_seg/cws_train.py b/reproduction/chinese_word_seg/cws_train.py index 6616ff5f..ff549eb9 100644 --- a/reproduction/chinese_word_seg/cws_train.py +++ b/reproduction/chinese_word_seg/cws_train.py @@ -15,7 +15,7 @@ from fastNLP.core.inference import Inference data_name = "pku_training.utf8" cws_data_path = "/home/zyfeng/data/pku_training.utf8" pickle_path = "./save/" -data_infer_path = "data_for_tests/people_infer.txt" +data_infer_path = "/home/zyfeng/data/pku_test.utf8" def infer(): @@ -59,7 +59,7 @@ def train(): train_data = loader.load_pku() # Preprocessor - p = POSPreprocess(train_data, pickle_path) + p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3) train_args["vocab_size"] = p.vocab_size train_args["num_classes"] = p.num_classes diff --git a/test/test_cws.py b/test/test_cws.py index 9d8f973c..8f6c1211 100644 --- a/test/test_cws.py +++ b/test/test_cws.py @@ -113,4 +113,4 @@ def train_test(): if __name__ == "__main__": train_test() - infer() + #infer()