Browse Source

small

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
f1cb6f6167
1 changed files with 2 additions and 1 deletions
  1. +2
    -1
      tests/modules/torch/encoder/test_seq2seq_encoder.py

+ 2
- 1
tests/modules/torch/encoder/test_seq2seq_encoder.py View File

@@ -1,12 +1,12 @@
import pytest import pytest


from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP import Vocabulary


if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch


from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder from fastNLP.modules.torch.encoder.seq2seq_encoder import TransformerSeq2SeqEncoder, LSTMSeq2SeqEncoder
from fastNLP import Vocabulary
from fastNLP.embeddings.torch import StaticEmbedding from fastNLP.embeddings.torch import StaticEmbedding




@@ -22,6 +22,7 @@ class TestTransformerSeq2SeqEncoder:
assert (encoder_output.size() == (1, 3, 10)) assert (encoder_output.size() == (1, 3, 10))




@pytest.mark.torch
class TestBiLSTMEncoder: class TestBiLSTMEncoder:
def test_case(self): def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split()) vocab = Vocabulary().add_word_lst("This is a test .".split())


Loading…
Cancel
Save