|
-
- import unittest
-
- from .model_runner import *
- from fastNLP.models.cnn_text_classification import CNNText
-
-
class TestCNNText(unittest.TestCase):
    """Smoke tests for ``CNNText`` driven by the shared ``RUNNER`` harness."""

    def init_model(self, kernel_sizes, kernel_nums=(1, 3, 5)):
        """Build a ``CNNText`` classifier over the shared test vocabulary.

        Args:
            kernel_sizes: Convolution kernel widths passed to ``CNNText``.
            kernel_nums: Per-kernel channel counts passed to ``CNNText``;
                presumably must be the same length as ``kernel_sizes``.

        Returns:
            The constructed ``CNNText`` model (embedding size 30,
            ``VOCAB_SIZE`` tokens, ``NUM_CLS`` output classes).
        """
        return CNNText((VOCAB_SIZE, 30),
                       NUM_CLS,
                       kernel_nums=kernel_nums,
                       kernel_sizes=kernel_sizes)

    def test_case1(self):
        """Check that a CNN with kernel widths 1, 3 and 5 runs end-to-end."""
        model = self.init_model((1, 3, 5))
        RUNNER.run_model_with_task(TEXT_CLS, model)

    def test_init_model(self):
        """Expect construction to fail for mismatched kernel configurations."""
        # Both calls pair these kernel_sizes with the default three-element
        # kernel_nums, so the two sequences differ in length.
        self.assertRaises(Exception, self.init_model, (2, 4))
        self.assertRaises(Exception, self.init_model, (2,))

    def test_output(self):
        """Run a single width-3 kernel with ``MAX_LEN`` patched to 2.

        Presumably this exercises inputs shorter than the kernel width.
        """
        from unittest import mock

        from . import model_runner

        model = self.init_model((3,), (1,))
        # ``from .model_runner import *`` only copied MAX_LEN into this
        # module, so rebinding it here (``global MAX_LEN``) is invisible to
        # code defined in model_runner, including RUNNER. Patch the attribute
        # on model_runner itself; patch.object restores it on exit so other
        # tests are unaffected.
        # NOTE(review): assumes RUNNER reads model_runner.MAX_LEN at call
        # time when generating data — confirm against model_runner.
        with mock.patch.object(model_runner, "MAX_LEN", 2):
            RUNNER.run_model_with_task(TEXT_CLS, model)
|