|
- import os
- import unittest
-
- from fastNLP.core.dataset import DataSet
- from fastNLP.core.predictor import Predictor
- from fastNLP.core.preprocess import save_pickle
- from fastNLP.core.vocabulary import Vocabulary
- from fastNLP.loader.base_loader import BaseLoader
- from fastNLP.loader.dataset_loader import convert_seq_dataset
- from fastNLP.models.cnn_text_classification import CNNText
- from fastNLP.models.sequence_modeling import SeqLabeling
-
-
class TestPredictor(unittest.TestCase):
    """Smoke tests for ``Predictor`` driving a text classifier and a sequence labeler."""

    def test_seq_label(self):
        """Run ``Predictor`` over a toy dataset with ``CNNText`` and then ``SeqLabeling``.

        The models are freshly constructed (untrained), so only the structure
        of the predictions is checked: one result per input sample; for
        classification, each result is a label string found in ``class_vocab``;
        for sequence labeling, each result is a list with one entry per input
        token.
        """
        import shutil
        import tempfile

        import fastNLP.core.predictor as pre

        model_args = {
            "vocab_size": 10,
            "word_emb_dim": 100,
            "rnn_hidden_units": 100,
            "num_classes": 5,
        }

        infer_data = [
            ['a', 'b', 'c', 'd', 'e'],
            ['a', '@', 'c', 'd', 'e'],
            ['a', 'b', '#', 'd', 'e'],
            ['a', 'b', 'c', '?', 'e'],
            ['a', 'b', 'c', 'd', '$'],
            ['!', 'b', 'c', 'd', 'e'],
        ]

        vocab = Vocabulary()
        vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
        class_vocab = Vocabulary()
        class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}

        # Pickle the vocabularies into a private temporary directory instead of
        # "./save" in the CWD. addCleanup removes it even when an assertion
        # below fails, so no stale state leaks into later runs.
        pickle_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, pickle_dir, ignore_errors=True)
        # Keep a trailing separator, as the original "./save/" had.
        # NOTE(review): presumably save_pickle/Predictor concatenate the
        # directory and file name; confirm before dropping the separator.
        pickle_path = os.path.join(pickle_dir, "")
        save_pickle(class_vocab, pickle_path, "label2id.pkl")
        save_pickle(vocab, pickle_path, "word2id.pkl")

        # --- Text classification: one label string per sample. ---
        model = CNNText(model_args)
        predictor = Predictor(pickle_path, pre.text_classify_post_processor)

        infer_data_set = convert_seq_dataset(infer_data)
        infer_data_set.index_field("word_seq", vocab)

        results = predictor.predict(network=model, data=infer_data_set)

        self.assertIsInstance(results, list)
        self.assertEqual(len(results), len(infer_data))
        for res in results:
            self.assertIsInstance(res, str)
            self.assertIn(res, class_vocab.word2idx)

        # --- Sequence labeling: one tag list per sample. ---
        # NOTE(review): set_origin_len presumably records each sample's
        # original length; the per-token length check below relies on it.
        infer_data_set.set_origin_len("word_seq")

        model = SeqLabeling(model_args)
        predictor = Predictor(pickle_path, pre.seq_label_post_processor)

        results = predictor.predict(network=model, data=infer_data_set)

        self.assertIsInstance(results, list)
        self.assertEqual(len(results), len(infer_data))
        for tokens, res in zip(infer_data, results):
            self.assertIsInstance(res, list)
            self.assertEqual(len(res), len(tokens))
-
-
class TestPredictor2(unittest.TestCase):
    """Placeholder for text-classification ``Predictor`` tests."""

    # Skipped rather than `pass`: an empty test body would report as PASSED
    # while testing nothing, hiding the missing coverage.
    @unittest.skip("TODO: text classification predictor test not written yet")
    def test_text_classify(self):
        """Test text classification prediction (not yet implemented)."""
|