You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_char_embedding.py 970 B

12345678910111213141516171819202122232425262728
  1. import unittest
  2. import torch
  3. from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding
  4. class TestCharEmbed(unittest.TestCase):
  5. def test_case_1(self):
  6. batch_size = 128
  7. char_emb = 100
  8. word_length = 1
  9. x = torch.Tensor(batch_size, char_emb, word_length)
  10. x = x.transpose(1, 2)
  11. cce = ConvCharEmbedding(char_emb)
  12. y = cce(x)
  13. self.assertEqual(tuple(x.shape), (batch_size, word_length, char_emb))
  14. print("CNN Char Emb input: ", x.shape)
  15. self.assertEqual(tuple(y.shape), (batch_size, char_emb, 1))
  16. print("CNN Char Emb output: ", y.shape) # [128, 100]
  17. lce = LSTMCharEmbedding(char_emb)
  18. o = lce(x)
  19. self.assertEqual(tuple(x.shape), (batch_size, word_length, char_emb))
  20. print("LSTM Char Emb input: ", x.shape)
  21. self.assertEqual(tuple(o.shape), (batch_size, char_emb, 1))
  22. print("LSTM Char Emb size: ", o.shape)