|
- import os
- import unittest
-
- from fastNLP.core.predictor import Predictor
- from fastNLP.core.preprocess import save_pickle
- from fastNLP.models.sequence_modeling import SeqLabeling
-
-
- class TestPredictor(unittest.TestCase):
- def test_seq_label(self):
- model_args = {
- "vocab_size": 10,
- "word_emb_dim": 100,
- "rnn_hidden_units": 100,
- "num_classes": 5
- }
-
- infer_data = [
- ['a', 'b', 'c', 'd', 'e'],
- ['a', '@', 'c', 'd', 'e'],
- ['a', 'b', '#', 'd', 'e'],
- ['a', 'b', 'c', '?', 'e'],
- ['a', 'b', 'c', 'd', '$'],
- ['!', 'b', 'c', 'd', 'e']
- ]
- vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
-
- os.system("mkdir save")
- save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl")
- save_pickle(vocab, "./save/", "word2id.pkl")
-
- model = SeqLabeling(model_args)
- predictor = Predictor("./save/", task="seq_label")
-
- results = predictor.predict(network=model, data=infer_data)
-
- self.assertTrue(isinstance(results, list))
- self.assertGreater(len(results), 0)
- for res in results:
- self.assertTrue(isinstance(res, list))
- self.assertEqual(len(res), 5)
- self.assertTrue(isinstance(res[0], str))
-
- os.system("rm -rf save")
- print("pickle path deleted")
-
-
- class TestPredictor2(unittest.TestCase):
- def test_text_classify(self):
- # TODO
- pass
|