@@ -64,7 +64,7 @@ class BaseTester(Action): | |||||
:return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). | :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s). | ||||
""" | """ | ||||
if self.save_dev_data is None: | if self.save_dev_data is None: | ||||
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||||
data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb")) | |||||
self.save_dev_data = data_dev | self.save_dev_data = data_dev | ||||
return self.save_dev_data | return self.save_dev_data | ||||
@@ -95,8 +95,11 @@ class POSPreprocess(BasePreprocess): | |||||
if not pickle_exist(pickle_path, "data_train.pkl"): | if not pickle_exist(pickle_path, "data_train.pkl"): | ||||
data_train = self.to_index(data) | data_train = self.to_index(data) | ||||
if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): | if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): | ||||
data_dev = data_train[: int(len(data_train) * train_dev_split)] | |||||
split = int(len(data_train) * train_dev_split) | |||||
data_dev = data_train[: split] | |||||
data_train = data_train[split:] | |||||
save_pickle(data_dev, self.pickle_path, "data_dev.pkl") | save_pickle(data_dev, self.pickle_path, "data_dev.pkl") | ||||
print("{} of the training data is split for validation. ".format(train_dev_split)) | |||||
save_pickle(data_train, self.pickle_path, "data_train.pkl") | save_pickle(data_train, self.pickle_path, "data_train.pkl") | ||||
def build_dict(self, data): | def build_dict(self, data): | ||||
@@ -1,5 +1,5 @@ | |||||
[train] | [train] | ||||
epochs = 2 | |||||
epochs = 10 | |||||
batch_size = 32 | batch_size = 32 | ||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
validate = true | validate = true | ||||
@@ -15,7 +15,7 @@ from fastNLP.core.inference import Inference | |||||
data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
cws_data_path = "/home/zyfeng/data/pku_training.utf8" | cws_data_path = "/home/zyfeng/data/pku_training.utf8" | ||||
pickle_path = "./save/" | pickle_path = "./save/" | ||||
data_infer_path = "data_for_tests/people_infer.txt" | |||||
data_infer_path = "/home/zyfeng/data/pku_test.utf8" | |||||
def infer(): | def infer(): | ||||
@@ -59,7 +59,7 @@ def train(): | |||||
train_data = loader.load_pku() | train_data = loader.load_pku() | ||||
# Preprocessor | # Preprocessor | ||||
p = POSPreprocess(train_data, pickle_path) | |||||
p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3) | |||||
train_args["vocab_size"] = p.vocab_size | train_args["vocab_size"] = p.vocab_size | ||||
train_args["num_classes"] = p.num_classes | train_args["num_classes"] = p.num_classes | ||||
@@ -113,4 +113,4 @@ def train_test(): | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_test() | train_test() | ||||
infer() | |||||
#infer() |