@@ -134,7 +134,10 @@ class BasePreprocess(object):
             results.append(data_dev)
         if test_data:
             results.append(data_test)
-        return tuple(results)
+        if len(results) == 1:
+            return results[0]
+        else:
+            return tuple(results)

     def build_dict(self, data):
         raise NotImplementedError
@@ -282,7 +285,8 @@ class ClassPreprocess(BasePreprocess):
         data_index = []
         for example in data:
             word_list = []
-            for word, label in zip(example[0]):
+            # example[0] is the word list, example[1] is the single label
+            for word in example[0]:
                 word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
             label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
             data_index.append([word_list, label_index])
@@ -95,10 +95,10 @@ num_classes = 27
 [text_class]
 epochs = 1
 batch_size = 10
-pickle_path = "./data_for_tests/"
+pickle_path = "./save_path/"
 validate = false
 save_best_dev = false
-model_saved_path = "./data_for_tests/"
+model_saved_path = "./save_path/"
 use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
@@ -14,7 +14,7 @@ from fastNLP.core.predictor import SeqLabelInfer

 data_name = "people.txt"
 data_path = "data_for_tests/people.txt"
-pickle_path = "data_for_tests"
+pickle_path = "seq_label/"
 data_infer_path = "data_for_tests/people_infer.txt"
@@ -33,21 +33,12 @@ def infer():
     model = SeqLabeling(test_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")

     # Data Loader
     raw_data_loader = BaseLoader(data_name, data_infer_path)
     infer_data = raw_data_loader.load_lines()
-    """
-        Transform strings into list of list of strings.
-        [
-            [word_11, word_12, ...],
-            [word_21, word_22, ...],
-            ...
-        ]
-        In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them.
-    """

     # Inference interface
     infer = SeqLabelInfer(pickle_path)
@@ -69,7 +60,7 @@ def train_and_test():

     # Preprocessor
     p = SeqLabelPreprocess()
-    data_train, data_dev = p.run(train_data, pickle_path, train_dev_split=0.5)
+    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes

@@ -84,7 +75,7 @@ def train_and_test():
     print("Training finished!")

     # Saver
-    saver = ModelSaver("./data_for_tests/saved_model.pkl")
+    saver = ModelSaver(pickle_path + "saved_model.pkl")
     saver.save_pytorch(model)
     print("Model saved!")
@@ -94,7 +85,7 @@ def train_and_test():
     model = SeqLabeling(train_args)

     # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl")
+    ModelLoader.load_pytorch(model, pickle_path + "saved_model.pkl")
     print("model loaded!")
     # Load test configuration
@@ -14,6 +14,7 @@ from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.models.cnn_text_classification import CNNText
 from fastNLP.saver.model_saver import ModelSaver

+save_path = "./test_classification/"
 data_dir = "./data_for_tests/"
 train_file = 'text_classify.txt'
 model_name = "model_class.pkl"
@@ -27,8 +28,8 @@ def infer():
     unlabeled_data = [x[0] for x in data]

     # pre-process data
-    pre = ClassPreprocess(data_dir)
-    vocab_size, n_classes = pre.process(data, "data_train.pkl")
+    pre = ClassPreprocess()
+    vocab_size, n_classes = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", vocab_size)
     print("number of classes:", n_classes)

@@ -60,7 +61,7 @@ def train():

     # pre-process data
     pre = ClassPreprocess()
-    data_train = pre.run(data, pickle_path=data_dir)
+    data_train = pre.run(data, pickle_path=save_path)
     print("vocabulary size:", pre.vocab_size)
     print("number of classes:", pre.num_classes)