|
|
@@ -1,12 +1,12 @@ |
|
|
|
import pytest |
|
|
|
|
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH |
|
|
|
from fastNLP import Vocabulary |
|
|
|
|
|
|
|
if _NEED_IMPORT_TORCH: |
|
|
|
import torch |
|
|
|
|
|
|
|
from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder |
|
|
|
from fastNLP import Vocabulary |
|
|
|
from fastNLP.embeddings.torch import StaticEmbedding |
|
|
|
|
|
|
|
|
|
|
@@ -22,6 +22,7 @@ class TestTransformerSeq2SeqEncoder: |
|
|
|
assert (encoder_output.size() == (1, 3, 10)) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
class TestBiLSTMEncoder: |
|
|
|
def test_case(self): |
|
|
|
vocab = Vocabulary().add_word_lst("This is a test .".split()) |
|
|
|