|
|
@@ -4,6 +4,7 @@ import unittest |
|
|
|
from fastNLP.core.predictor import Predictor |
|
|
|
from fastNLP.core.preprocess import save_pickle |
|
|
|
from fastNLP.models.sequence_modeling import SeqLabeling |
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
|
|
|
|
|
|
|
|
class TestPredictor(unittest.TestCase): |
|
|
@@ -23,10 +24,14 @@ class TestPredictor(unittest.TestCase): |
|
|
|
['a', 'b', 'c', 'd', '$'], |
|
|
|
['!', 'b', 'c', 'd', 'e'] |
|
|
|
] |
|
|
|
vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} |
|
|
|
|
|
|
|
vocab = Vocabulary() |
|
|
|
vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} |
|
|
|
class_vocab = Vocabulary() |
|
|
|
class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} |
|
|
|
|
|
|
|
os.system("mkdir save") |
|
|
|
save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") |
|
|
|
save_pickle(class_vocab, "./save/", "class2id.pkl") |
|
|
|
save_pickle(vocab, "./save/", "word2id.pkl") |
|
|
|
|
|
|
|
model = SeqLabeling(model_args) |
|
|
|