|
"""Train a POS sequence-labeling model, save it, reload it and test it.

Run as a script. Every path below is relative, so the script expects to be
started from the directory that contains ``data_for_tests``.
"""
import sys

# Kept ahead of the fastNLP imports: adds the parent directory to the module
# search path (presumably so an uninstalled fastNLP checkout is importable).
sys.path.append("..")

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.action.trainer import POSTrainer
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.action.tester import POSTester
from fastNLP.models.sequence_modeling import SeqLabeling

data_name = "people.txt"
data_path = "data_for_tests/people.txt"
pickle_path = "data_for_tests"
# Shared by the training and the testing half of the pipeline; defined once
# so the two halves cannot drift apart.
config_path = "./data_for_tests/config"
model_path = "./saved_model.pkl"


def _load_section(section_name):
    """Create a ConfigSection and pass it to the config loader.

    :param section_name: key under which the new section is handed to
        ``ConfigLoader.load_config`` ("POS" for training, "POS_test" for
        testing in this script).
    :return: the ConfigSection object after ``load_config`` has run.
    """
    section = ConfigSection()
    ConfigLoader("config.cfg", "").load_config(config_path, {section_name: section})
    return section


def _build_model(num_classes, vocab_size):
    """Build the SeqLabeling model with the arguments this script always uses.

    Training and reloading must construct the model identically, otherwise
    the saved weights may not fit the rebuilt model; keeping the constructor
    call in one place guarantees that.

    :param num_classes: value taken from ``POSPreprocess.num_classes``.
    :param vocab_size: value taken from ``POSPreprocess.vocab_size``.
    :return: a new SeqLabeling instance.
    """
    # NOTE(review): 100 and 1 are the first two positional arguments of
    # SeqLabeling; their meaning is not visible in this file - confirm
    # against the model's signature before changing them.
    return SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)


def main():
    """Run the pipeline: load config and data, train, save, reload, test."""
    train_args = _load_section("POS")

    # Data Loader
    pos_loader = POSDatasetLoader(data_name, data_path)
    train_data = pos_loader.load_lines()

    # Preprocessor
    p = POSPreprocess(train_data, pickle_path)
    vocab_size = p.vocab_size
    num_classes = p.num_classes

    # The trainer reads both sizes from its argument section.
    train_args["vocab_size"] = vocab_size
    train_args["num_classes"] = num_classes

    trainer = POSTrainer(train_args)

    # Model
    model = _build_model(num_classes, vocab_size)

    # Start training
    trainer.train(model)
    print("Training finished!")

    saver = ModelSaver(model_path)
    saver.save_pytorch(model)
    print("Model saved!")

    # Drop the training objects; what is tested below is a freshly built
    # model whose state is read back from model_path.
    del model, trainer, pos_loader

    model = _build_model(num_classes, vocab_size)
    # "xxx" is the placeholder first argument used by the original script.
    ModelLoader("xxx", model_path).load_pytorch(model)
    print("model loaded!")

    test_args = _load_section("POS_test")
    # Hand-written options kept from the original script for reference:
    # test_args = {"save_output": True, "validate_in_training": False, "save_dev_input": False,
    #              "save_loss": True, "batch_size": 1, "pickle_path": pickle_path}
    tester = POSTester(test_args)
    tester.test(model)
    print("model tested!")


if __name__ == "__main__":
    main()
|