@@ -134,7 +134,10 @@ class BasePreprocess(object):
             results.append(data_dev)
         if test_data:
             results.append(data_test)
-        return tuple(results)
+        if len(results) == 1:
+            return results[0]
+        else:
+            return tuple(results)

     def build_dict(self, data):
         raise NotImplementedError
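
With this change, BasePreprocess.run hands back the processed training set directly when only one dataset is produced, and a tuple when a dev/test split is requested. A minimal sketch of the two calling patterns (paths and split ratio are illustrative, not from the patch):

    p = SeqLabelPreprocess()

    # Single result: run() returns the dataset itself, not a 1-tuple.
    data_train = p.run(train_data, pickle_path="./save/")

    # Two results: run() returns a tuple to unpack.
    data_train, data_dev = p.run(train_data, pickle_path="./save/", train_dev_split=0.3)
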
@@ -282,7 +285,8 @@ class ClassPreprocess(BasePreprocess):
         data_index = []
         for example in data:
             word_list = []
-            for word, label in zip(example[0]):
+            # example[0] is the word list, example[1] is the single label
+            for word in example[0]:
                 word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
             label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
             data_index.append([word_list, label_index])
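
The old loop unpacked two names from zip(example[0]), which yields 1-tuples and raises a ValueError on the first iteration; the fix iterates over the words alone. The same lookup pattern with plain dicts standing in for the class attributes:

    word2index = {"<unk>": 0, "the": 1, "cat": 2}
    label2index = {"<unk>": 0, "positive": 1}
    example = (["the", "cat", "sat"], "positive")

    # Out-of-vocabulary words ("sat") fall back to the <unk> index.
    word_list = [word2index.get(w, word2index["<unk>"]) for w in example[0]]
    label_index = label2index.get(example[1], label2index["<unk>"])
    # word_list == [1, 2, 0], label_index == 1
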
@@ -95,10 +95,10 @@ num_classes = 27
 [text_class]
 epochs = 1
 batch_size = 10
-pickle_path = "./data_for_tests/"
+pickle_path = "./save_path/"
 validate = false
 save_best_dev = false
-model_saved_path = "./data_for_tests/"
+model_saved_path = "./save_path/"
 use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
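
Both paths in the [text_class] section now point at ./save_path/. A quick sanity check of the section with the standard library (configparser is used here only for illustration; the file name is assumed):

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("config.cfg")  # assumed file name
    # Raw INI parsing keeps the surrounding quotes.
    print(cfg["text_class"]["pickle_path"])       # "./save_path/"
    print(cfg["text_class"]["model_saved_path"])  # "./save_path/"
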
@@ -14,7 +14,7 @@ from fastNLP.core.predictor import SeqLabelInfer
 data_name = "people.txt"
 data_path = "data_for_tests/people.txt"
-pickle_path = "data_for_tests"
+pickle_path = "seq_label/"
 data_infer_path = "data_for_tests/people_infer.txt"
@@ -33,21 +33,12 @@ def infer():
     model = SeqLabeling(test_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")

     # Data Loader
     raw_data_loader = BaseLoader(data_name, data_infer_path)
     infer_data = raw_data_loader.load_lines()
-    """
-        Transform strings into list of list of strings.
-        [
-            [word_11, word_12, ...],
-            [word_21, word_22, ...],
-            ...
-        ]
-        In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
-    """

     # Inference interface
     infer = SeqLabelInfer(pickle_path)
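
The deleted docstring documented the shape that load_lines() produces: one token list per input line. For reference, a hand-built value of that shape (tokens are made up):

    # Each inner list is one tokenized sentence from people_infer.txt.
    infer_data = [
        ["word_11", "word_12"],
        ["word_21", "word_22"],
    ]
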
@@ -69,7 +60,7 @@ def train_and_test():

     # Preprocessor
     p = SeqLabelPreprocess()
-    data_train, data_dev = p.run(train_data, pickle_path, train_dev_split=0.5)
+    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)

     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes
@@ -84,7 +75,7 @@ def train_and_test():
     print("Training finished!")

     # Saver
-    saver = ModelSaver("./data_for_tests/saved_model.pkl")
+    saver = ModelSaver(pickle_path + "saved_model.pkl")
     saver.save_pytorch(model)
     print("Model saved!")
@@ -94,7 +85,7 @@ def train_and_test():
     model = SeqLabeling(train_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")

     # Load test configuration
@@ -14,6 +14,7 @@ from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.models.cnn_text_classification import CNNText
 from fastNLP.saver.model_saver import ModelSaver

+save_path = "./test_classification/"
 data_dir = "./data_for_tests/"
 train_file = 'text_classify.txt'
 model_name = "model_class.pkl"
@@ -27,8 +28,8 @@ def infer():
     unlabeled_data = [x[0] for x in data]

     # pre-process data
-    pre = ClassPreprocess(data_dir)
-    vocab_size, n_classes = pre.process(data, "data_train.pkl")
+    pre = ClassPreprocess()
+    vocab_size, n_classes = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", vocab_size)
     print("number of classes:", n_classes)
@@ -60,7 +61,7 @@ def train():

     # pre-process data
     pre = ClassPreprocess()
-    data_train = pre.run(data, pickle_path=data_dir)
+    data_train = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", pre.vocab_size)
     print("number of classes:", pre.num_classes)