|
- # Python: 3.5
- # encoding: utf-8
-
- import os
-
- from fastNLP.core.inference import ClassificationInfer
- from fastNLP.core.trainer import ClassificationTrainer
- from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
- from fastNLP.loader.dataset_loader import ClassDatasetLoader
- from fastNLP.loader.model_loader import ModelLoader
- from fastNLP.loader.preprocess import ClassPreprocess
- from fastNLP.models.cnn_text_classification import CNNText
- from fastNLP.saver.model_saver import ModelSaver
-
# Directory holding the test corpus; it is also handed to ClassPreprocess and
# ClassificationInfer below.
data_dir = "./data_for_tests/"
# Corpus file name, resolved relative to data_dir.
train_file = "text_classify.txt"
# NOTE(review): not referenced anywhere in the visible part of this file; the
# functions below use "saved_model.pkl" instead -- confirm whether it is needed.
model_name = "model_class.pkl"
-
-
def infer():
    """Load the saved CNN text classifier and print its predictions.

    Reads the corpus ``train_file`` from ``data_dir``, keeps element 0 of
    every loaded record as model input, runs the pre-processing step, builds
    ``CNNText`` from the ``text_class_model`` config section, loads the
    parameters stored in ``./data_for_tests/saved_model.pkl`` into it and
    prints the value returned by ``ClassificationInfer.predict``.

    Returns nothing; all output goes to stdout.
    """
    # load dataset
    print("Loading data...")
    corpus_path = os.path.join(data_dir, train_file)
    records = ClassDatasetLoader("train", corpus_path).load()
    # Only element 0 of each record is passed to the model
    # (presumably the text, with the label dropped -- TODO confirm).
    # Extracted before pre-processing, exactly as in the original order.
    inputs = [record[0] for record in records]

    # pre-process data
    preprocessor = ClassPreprocess(data_dir)
    vocab_size, n_classes = preprocessor.process(records, "data_train.pkl")
    print("vocabulary size:", vocab_size)
    print("number of classes:", n_classes)

    # read the model hyper-parameters from the shared test config
    model_config = ConfigSection()
    ConfigLoader.load_config(
        "data_for_tests/config",
        {"text_class_model": model_config},
    )

    # construct model
    print("Building model...")
    classifier = CNNText(model_config)

    # Dump trained parameters into the model
    ModelLoader.load_pytorch(classifier, "./data_for_tests/saved_model.pkl")
    print("model loaded!")

    # run prediction on the unlabeled inputs and show the result
    predictor = ClassificationInfer(data_dir)
    predictions = predictor.predict(classifier, inputs)
    print(predictions)
-
-
def train():
    """Train the CNN text classifier on the test corpus and save it.

    Loads the ``text_class`` (trainer) and ``text_class_model`` (model)
    sections from ``data_for_tests/config``, reads and pre-processes the
    corpus ``train_file`` under ``data_dir``, trains a ``CNNText`` model with
    ``ClassificationTrainer`` and writes it to
    ``./data_for_tests/saved_model.pkl``.

    Returns nothing; progress messages go to stdout.
    """
    trainer_config = ConfigSection()
    model_config = ConfigSection()
    ConfigLoader.load_config(
        "data_for_tests/config",
        {"text_class": trainer_config, "text_class_model": model_config},
    )

    # load dataset
    print("Loading data...")
    corpus_path = os.path.join(data_dir, train_file)
    records = ClassDatasetLoader("train", corpus_path).load()
    # show the first record as a quick sanity check of the loader output
    print(records[0])

    # pre-process data
    preprocessor = ClassPreprocess(data_dir)
    vocab_size, n_classes = preprocessor.process(records, "data_train.pkl")
    print("vocabulary size:", vocab_size)
    print("number of classes:", n_classes)

    # construct model
    print("Building model...")
    classifier = CNNText(model_config)

    # train
    print("Training...")

    # The trainer is given only the model; presumably it picks up the
    # pre-processed data through its config -- TODO confirm.
    fitter = ClassificationTrainer(trainer_config)
    fitter.train(classifier)

    print("Training finished!")

    # persist the trained model to the path that infer() loads from
    ModelSaver("./data_for_tests/saved_model.pkl").save_pytorch(classifier)
    print("Model saved!")
-
-
if __name__ == "__main__":
    # Only inference runs by default. infer() loads
    # ./data_for_tests/saved_model.pkl, which train() writes, so call train()
    # here first when that file does not exist yet.
    infer()
|