|
|
@@ -1,6 +1,7 @@ |
|
|
|
import os |
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader |
|
|
|
from fastNLP.core.metrics import SeqLabelEvaluator |
|
|
|
from fastNLP.core.optimizer import Optimizer |
|
|
|
from fastNLP.core.preprocess import save_pickle |
|
|
@@ -25,14 +26,19 @@ def test_training(): |
|
|
|
ConfigLoader().load_config(config_dir, { |
|
|
|
"test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) |
|
|
|
|
|
|
|
data_set = DataSet() |
|
|
|
word_vocab = V |
|
|
|
data_set = TokenizeDataSetLoader().load(data_path) |
|
|
|
word_vocab = Vocabulary() |
|
|
|
label_vocab = Vocabulary() |
|
|
|
data_set.update_vocab(word_seq=word_vocab, label_seq=label_vocab) |
|
|
|
data_set.index_field("word_seq", word_vocab).index_field("label_seq", label_vocab) |
|
|
|
data_set.set_origin_len("word_seq") |
|
|
|
data_set.rename_field("label_seq", "truth").set_target(truth=False) |
|
|
|
data_train, data_dev = data_set.split(0.3, shuffle=True) |
|
|
|
model_args["vocab_size"] = len(data_set.word_vocab) |
|
|
|
model_args["num_classes"] = len(data_set.label_vocab) |
|
|
|
model_args["vocab_size"] = len(word_vocab) |
|
|
|
model_args["num_classes"] = len(label_vocab) |
|
|
|
|
|
|
|
save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") |
|
|
|
save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") |
|
|
|
save_pickle(word_vocab, pickle_path, "word2id.pkl") |
|
|
|
save_pickle(label_vocab, pickle_path, "label2id.pkl") |
|
|
|
|
|
|
|
trainer = SeqLabelTrainer( |
|
|
|
epochs=trainer_args["epochs"], |
|
|
|