import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP import Vocabulary

if _NEED_IMPORT_TORCH:
    import torch

    from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
    from fastNLP.embeddings.torch import StaticEmbedding

@pytest.mark.torch
class TestTransformerSeq2SeqEncoder:
    def test_case(self):
        # Build a tiny vocabulary and a 5-dim static embedding over it.
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        # The encoder maps the 5-dim embeddings to d_model=10 internally.
        encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2)
        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)  # one sequence of 3 tokens
        seq_len = torch.LongTensor([3])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)
        # Output shape is (batch_size, max_len, d_model).
        assert encoder_output.size() == (1, 3, 10)
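
    # A minimal variable-length-batch sketch, not in the original file: it
    # assumes the encoder pads to max_len and that encoder_mask marks non-pad
    # positions (True where a real token sits), following fastNLP's
    # seq_len_to_mask convention.
    def test_variable_length_batch(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        encoder = TransformerSeq2SeqEncoder(embed, num_layers=2, d_model=10, n_head=2)
        # Two sequences of lengths 3 and 2; the second row ends in padding (index 0).
        words_idx = torch.LongTensor([[0, 1, 2], [1, 2, 0]])
        seq_len = torch.LongTensor([3, 2])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)
        assert encoder_output.size() == (2, 3, 10)
        # If the mask follows the non-pad convention, each row sums to its length.
        assert (encoder_mask.sum(dim=-1) == seq_len).all()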


@pytest.mark.torch
class TestBiLSTMEncoder:
    def test_case(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        # LSTMSeq2SeqEncoder is bidirectional by default, hence the BiLSTM name.
        encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)
        words_idx = torch.LongTensor([0, 1, 2]).unsqueeze(0)  # one sequence of 3 tokens
        seq_len = torch.LongTensor([3])

        encoder_output, encoder_mask = encoder(words_idx, seq_len)
        # Mask shape is (batch_size, max_len).
        assert encoder_mask.size() == (1, 3)
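
    # A mask-consistency sketch added here, not in the original file: it
    # assumes the same non-pad mask convention as above, so each row's mask
    # sum should equal its seq_len.
    def test_mask_matches_seq_len(self):
        vocab = Vocabulary().add_word_lst("This is a test .".split())
        embed = StaticEmbedding(vocab, embedding_dim=5)
        encoder = LSTMSeq2SeqEncoder(embed, hidden_size=5, num_layers=1)
        words_idx = torch.LongTensor([[0, 1, 2], [1, 2, 0]])
        seq_len = torch.LongTensor([3, 2])
        encoder_output, encoder_mask = encoder(words_idx, seq_len)
        # Batch and length dims should match the input; the feature dim depends
        # on how the implementation splits hidden_size across directions, so it
        # is deliberately left unchecked here.
        assert encoder_output.size()[:2] == (2, 3)
        assert (encoder_mask.sum(dim=-1) == seq_len).all()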