@@ -1,12 +1,12 @@
 import pytest
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
-from fastNLP import Vocabulary
 
 if _NEED_IMPORT_TORCH:
     import torch
     from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
+    from fastNLP import Vocabulary
     from fastNLP.embeddings.torch import StaticEmbedding
@@ -22,6 +22,7 @@ class TestTransformerSeq2SeqEncoder:
         assert (encoder_output.size() == (1, 3, 10))
 
+@pytest.mark.torch
 class TestBiLSTMEncoder:
     def test_case(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
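
For context, the patch pairs the _NEED_IMPORT_TORCH import guard with a pytest.mark.torch marker: the guard keeps the module importable on machines without torch, and the marker lets CI deselect torch-only tests. Below is a minimal sketch of that pattern; the TestExample class and its assertion are hypothetical, and the "torch" marker is assumed to be registered in the project's pytest configuration.

import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

# Guard: torch-dependent imports only execute when torch is installed,
# so merely collecting this test module never raises ImportError.
if _NEED_IMPORT_TORCH:
    import torch


@pytest.mark.torch  # deselectable via: pytest -m "not torch"
class TestExample:
    def test_case(self):
        # Runs only when torch tests are selected, so torch is available here.
        assert torch.LongTensor([3]).item() == 3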