|
- import pytest
- from .model_runner import *
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
- if _NEED_IMPORT_TORCH:
- from fastNLP.models.torch.sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
-
-
@pytest.mark.torch
class TestBiLSTM:
    """Run the ``BiLSTMCRF`` model through the shared ``RUNNER`` on POS-tagging data.

    Any pass/fail checking beyond "does not raise" happens inside
    ``RUNNER.run_model`` (not visible here).
    """

    def test_case1(self):
        # Check that BiLSTMCRF can be constructed and run end-to-end.
        # presumably (vocab size, embedding dim) -- TODO confirm against BiLSTMCRF's init_embed contract
        embed_spec = (VOCAB_SIZE, 30)
        net = BiLSTMCRF(embed_spec, hidden_size=30, num_classes=NUM_CLS)

        pos_data = RUNNER.prepare_pos_tagging_data()
        RUNNER.run_model(net, pos_data, Accuracy())
-
-
@pytest.mark.torch
class TestSeqLabel:
    """Run the ``SeqLabeling`` model through the shared ``RUNNER`` on POS-tagging data.

    Any pass/fail checking beyond "does not raise" happens inside
    ``RUNNER.run_model`` (not visible here).
    """

    def test_case1(self):
        # Check that SeqLabeling can be constructed and run end-to-end.
        # presumably (vocab size, embedding dim) -- TODO confirm against SeqLabeling's init_embed contract
        embed_spec = (VOCAB_SIZE, 30)
        net = SeqLabeling(embed_spec, hidden_size=30, num_classes=NUM_CLS)

        pos_data = RUNNER.prepare_pos_tagging_data()
        RUNNER.run_model(net, pos_data, Accuracy())
-
-
@pytest.mark.torch
class TestAdvSeqLabel:
    """Run the ``AdvSeqLabel`` model through the shared ``RUNNER`` on POS-tagging data.

    Any pass/fail checking beyond "does not raise" happens inside
    ``RUNNER.run_model`` (not visible here).
    """

    def test_case1(self):
        # Check that AdvSeqLabel can be constructed and run end-to-end.
        # presumably (vocab size, embedding dim) -- TODO confirm against AdvSeqLabel's init_embed contract
        embed_spec = (VOCAB_SIZE, 30)
        net = AdvSeqLabel(embed_spec, hidden_size=30, num_classes=NUM_CLS)

        pos_data = RUNNER.prepare_pos_tagging_data()
        RUNNER.run_model(net, pos_data, Accuracy())
|