
fix bugs and clean up

Tag: v0.1.0
Author: FengZiYjun, 7 years ago
Commit: fac830e1cd
4 changed files with 17 additions and 21 deletions
  1. fastNLP/core/preprocess.py (+6, -2)
  2. test/data_for_tests/config (+2, -2)
  3. test/seq_labeling.py (+5, -14)
  4. test/text_classify.py (+4, -3)

fastNLP/core/preprocess.py (+6, -2)

@@ -134,7 +134,10 @@ class BasePreprocess(object):
             results.append(data_dev)
         if test_data:
             results.append(data_test)
-        return tuple(results)
+        if len(results) == 1:
+            return results[0]
+        else:
+            return tuple(results)
 
     def build_dict(self, data):
         raise NotImplementedError
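This change means run() no longer wraps a single result in a 1-tuple. A minimal sketch of the two call patterns, assuming run() accepts the same keyword arguments used in test/seq_labeling.py further down:

    p = SeqLabelPreprocess()

    # Only training data: run() now returns the dataset itself, not a 1-tuple.
    data_train = p.run(train_data, pickle_path=pickle_path)

    # With a dev split requested, multiple datasets still come back as a tuple.
    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)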
@@ -282,7 +285,8 @@ class ClassPreprocess(BasePreprocess):
         data_index = []
         for example in data:
             word_list = []
-            for word, label in zip(example[0]):
+            # example[0] is the word list, example[1] is the single label
+            for word in example[0]:
                 word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
             label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
             data_index.append([word_list, label_index])
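The replaced line was the actual bug: zip() over a single list yields 1-tuples, so unpacking each element into (word, label) raises a ValueError on the first word. A standalone illustration with made-up data:

    example = [["the", "cat", "sat"], "statement"]  # hypothetical [word_list, label] pair

    # Old loop: zip(example[0]) yields ("the",), ("cat",), ... and the two-name
    # unpacking fails with "not enough values to unpack (expected 2, got 1)".
    try:
        for word, label in zip(example[0]):
            pass
    except ValueError as err:
        print("old loop fails:", err)

    # New loop: iterate the word list directly.
    for word in example[0]:
        print(word)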


test/data_for_tests/config (+2, -2)

@@ -95,10 +95,10 @@ num_classes = 27
 [text_class]
 epochs = 1
 batch_size = 10
-pickle_path = "./data_for_tests/"
+pickle_path = "./save_path/"
 validate = false
 save_best_dev = false
-model_saved_path = "./data_for_tests/"
+model_saved_path = "./save_path/"
 use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
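The [text_class] section is plain INI-style text, so the new paths can be sanity-checked with Python's standard configparser; this is only an illustration (assuming the rest of the file parses as valid INI), since fastNLP's own ConfigLoader may strip quotes and coerce types differently:

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("test/data_for_tests/config")

    section = cfg["text_class"]
    print(section.get("pickle_path"))       # '"./save_path/"' (configparser keeps the quotes)
    print(section.get("model_saved_path"))  # '"./save_path/"'
    print(section.getint("batch_size"))     # 10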


test/seq_labeling.py (+5, -14)

@@ -14,7 +14,7 @@ from fastNLP.core.predictor import SeqLabelInfer
 
 data_name = "people.txt"
 data_path = "data_for_tests/people.txt"
-pickle_path = "data_for_tests"
+pickle_path = "seq_label/"
 data_infer_path = "data_for_tests/people_infer.txt"


@@ -33,21 +33,12 @@ def infer():
     model = SeqLabeling(test_args)
 
     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")
 
     # Data Loader
     raw_data_loader = BaseLoader(data_name, data_infer_path)
     infer_data = raw_data_loader.load_lines()
-    """
-    Transform strings into list of list of strings.
-    [
-        [word_11, word_12, ...],
-        [word_21, word_22, ...],
-        ...
-    ]
-    In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
-    """
 
     # Inference interface
     infer = SeqLabelInfer(pickle_path)
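The removed block only documented the shape load_lines() produces, so dropping it is harmless; for reference, that shape is a list of token lists, one per line of people_infer.txt, along these lines (tokens are placeholders):

    # Shape described by the removed docstring:
    infer_data = [
        ["word_11", "word_12", "word_13"],  # tokens from the first input line
        ["word_21", "word_22"],             # tokens from the second input line
    ]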
@@ -69,7 +60,7 @@ def train_and_test():
 
     # Preprocessor
     p = SeqLabelPreprocess()
-    data_train, data_dev = p.run(train_data, pickle_path, train_dev_split=0.5)
+    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes

@@ -84,7 +75,7 @@ def train_and_test():
     print("Training finished!")
 
     # Saver
-    saver = ModelSaver("./data_for_tests/saved_model.pkl")
+    saver = ModelSaver(pickle_path + "saved_model.pkl")
     saver.save_pytorch(model)
     print("Model saved!")

@@ -94,7 +85,7 @@ def train_and_test():
     model = SeqLabeling(train_args)
 
     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")
 
     # Load test configuration
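Taken together, the last two hunks make saving and reloading use one path derived from pickle_path. The round-trip the test now performs, condensed (plain string concatenation is what the test uses; os.path.join would be more defensive if pickle_path ever lost its trailing slash):

    model_file = pickle_path + "saved_model.pkl"   # "seq_label/saved_model.pkl"

    saver = ModelSaver(model_file)
    saver.save_pytorch(model)                      # persist the trained weights

    model = SeqLabeling(train_args)                # fresh model with the same architecture
    ModelLoader.load_pytorch(model, model_file)    # load the weights back in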


test/text_classify.py (+4, -3)

@@ -14,6 +14,7 @@ from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.models.cnn_text_classification import CNNText
 from fastNLP.saver.model_saver import ModelSaver
 
+save_path = "./test_classification/"
 data_dir = "./data_for_tests/"
 train_file = 'text_classify.txt'
 model_name = "model_class.pkl"
@@ -27,8 +28,8 @@ def infer():
     unlabeled_data = [x[0] for x in data]
 
     # pre-process data
-    pre = ClassPreprocess(data_dir)
-    vocab_size, n_classes = pre.process(data, "data_train.pkl")
+    pre = ClassPreprocess()
+    vocab_size, n_classes = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", vocab_size)
     print("number of classes:", n_classes)

@@ -60,7 +61,7 @@ def train():
 
     # pre-process data
     pre = ClassPreprocess()
-    data_train = pre.run(data, pickle_path=data_dir)
+    data_train = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", pre.vocab_size)
     print("number of classes:", pre.num_classes)


