|
- import pytest
-
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
-
- if _NEED_IMPORT_TORCH:
- import torch
- from fastNLP.modules.torch.encoder.star_transformer import StarTransformer
-
-
@pytest.mark.torch
class TestStarTransformer:
    """Output-shape checks for ``StarTransformer`` on a random input batch."""

    def test_1(self):
        """Run one forward pass and verify the shapes of both returned tensors."""
        batch_size, seq_len, hidden_size = 16, 45, 100

        encoder = StarTransformer(
            num_layers=6,
            hidden_size=hidden_size,
            num_head=8,
            head_dim=20,
            max_len=100,
        )
        inputs = torch.rand(batch_size, seq_len, hidden_size)
        # All-ones mask, cast to uint8 exactly as the model is called here.
        all_ones_mask = torch.ones(batch_size, seq_len).byte()

        first_out, second_out = encoder(inputs, all_ones_mask)

        # First output keeps the full (batch, sequence, hidden) layout of the input.
        assert tuple(first_out.shape) == (batch_size, seq_len, hidden_size)
        # Second output has no sequence axis: one hidden-size vector per batch item.
        assert tuple(second_out.shape) == (batch_size, hidden_size)
|