|
import unittest

import torch

from fastNLP.modules.encoder.star_transformer import StarTransformer


class TestStarTransformer(unittest.TestCase):
    """Output-shape checks for the ``StarTransformer`` encoder."""

    def test_1(self):
        """A forward pass yields a (batch, seq, hidden) tensor and a (batch, hidden) tensor."""
        batch_size, seq_len, hidden = 16, 45, 100

        encoder = StarTransformer(
            num_layers=6,
            hidden_size=hidden,
            num_head=8,
            head_dim=20,
            max_len=100,
        )
        inputs = torch.rand(batch_size, seq_len, hidden)
        # All-ones mask. Presumably this means no position is masked out
        # (i.e. 1 marks a valid token) -- TODO confirm against StarTransformer.forward.
        mask = torch.ones(batch_size, seq_len).byte()

        sequence_out, summary_out = encoder(inputs, mask)

        self.assertEqual(tuple(sequence_out.size()), (batch_size, seq_len, hidden))
        self.assertEqual(tuple(summary_out.size()), (batch_size, hidden))
|