Browse Source

fix bugs in preprocessor

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
743a6d7547
5 changed files with 9 additions and 6 deletions
  1. fastNLP/core/tester.py (+1, -1)
  2. fastNLP/loader/preprocess.py (+4, -1)
  3. reproduction/chinese_word_seg/cws.cfg (+1, -1)
  4. reproduction/chinese_word_seg/cws_train.py (+2, -2)
  5. test/test_cws.py (+1, -1)

+ 1
- 1
fastNLP/core/tester.py View File

@@ -64,7 +64,7 @@ class BaseTester(Action):
  :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
  """
  if self.save_dev_data is None:
- data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
+ data_dev = _pickle.load(open(data_path + "/data_dev.pkl", "rb"))
  self.save_dev_data = data_dev
  return self.save_dev_data




+ 4
- 1
fastNLP/loader/preprocess.py View File

@@ -95,8 +95,11 @@ class POSPreprocess(BasePreprocess):
  if not pickle_exist(pickle_path, "data_train.pkl"):
  data_train = self.to_index(data)
  if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
- data_dev = data_train[: int(len(data_train) * train_dev_split)]
+ split = int(len(data_train) * train_dev_split)
+ data_dev = data_train[: split]
+ data_train = data_train[split:]
  save_pickle(data_dev, self.pickle_path, "data_dev.pkl")
+ print("{} of the training data is split for validation. ".format(train_dev_split))
  save_pickle(data_train, self.pickle_path, "data_train.pkl")


  def build_dict(self, data):


+ 1
- 1
reproduction/chinese_word_seg/cws.cfg View File

@@ -1,5 +1,5 @@
  [train]
- epochs = 2
+ epochs = 10
  batch_size = 32
  pickle_path = "./save/"
  validate = true


+ 2
- 2
reproduction/chinese_word_seg/cws_train.py View File

@@ -15,7 +15,7 @@ from fastNLP.core.inference import Inference
  data_name = "pku_training.utf8"
  cws_data_path = "/home/zyfeng/data/pku_training.utf8"
  pickle_path = "./save/"
- data_infer_path = "data_for_tests/people_infer.txt"
+ data_infer_path = "/home/zyfeng/data/pku_test.utf8"




  def infer():
@@ -59,7 +59,7 @@ def train():
  train_data = loader.load_pku()


  # Preprocessor
- p = POSPreprocess(train_data, pickle_path)
+ p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3)
  train_args["vocab_size"] = p.vocab_size
  train_args["num_classes"] = p.num_classes




+ 1
- 1
test/test_cws.py View File

@@ -113,4 +113,4 @@ def train_test():


  if __name__ == "__main__":
  train_test()
- infer()
+ #infer()

Loading…
Cancel
Save