diff --git a/tests/modules/torch/encoder/test_seq2seq_encoder.py b/tests/modules/torch/encoder/test_seq2seq_encoder.py
index 97aa5a7c..3570fd16 100755
--- a/tests/modules/torch/encoder/test_seq2seq_encoder.py
+++ b/tests/modules/torch/encoder/test_seq2seq_encoder.py
@@ -1,12 +1,12 @@
 import pytest

 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP import Vocabulary

 if _NEED_IMPORT_TORCH:
     import torch

     from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
-    from fastNLP import Vocabulary
     from fastNLP.embeddings.torch import StaticEmbedding


@@ -22,6 +22,7 @@ class TestTransformerSeq2SeqEncoder:
         assert (encoder_output.size() == (1, 3, 10))


+@pytest.mark.torch
 class TestBiLSTMEncoder:
     def test_case(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
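
For context on the two changes: `Vocabulary` is importable without torch, so its import can live at module level, while the torch-dependent imports stay behind fastNLP's `_NEED_IMPORT_TORCH` guard so collection succeeds in torch-free environments. The added `@pytest.mark.torch` then lets `TestBiLSTMEncoder` be selected or deselected at collection time like the other torch-dependent tests. Below is a minimal sketch of one way a conftest could consume that marker; this is hypothetical and fastNLP's actual conftest may differ.

# conftest.py (hypothetical sketch, not part of this diff)
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH


def pytest_collection_modifyitems(config, items):
    # If torch is available, run everything as collected.
    if _NEED_IMPORT_TORCH:
        return
    # Otherwise, skip any test carrying the `torch` marker.
    skip_torch = pytest.mark.skip(reason="torch is not installed")
    for item in items:
        if "torch" in item.keywords:
            item.add_marker(skip_torch)

With something like this in place, `pytest -m torch` selects only the torch tests, while a plain `pytest` run skips them gracefully when torch is missing.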