import unittest

import torch

from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding


class TestCharEmbed(unittest.TestCase):
    """Shape checks for the character-level embedding encoders."""

    def test_case_1(self):
        """Run both char-embedding modules on one batch and check output shapes.

        The same input tensor is fed to ``ConvCharEmbedding`` and
        ``LSTMCharEmbedding``; each is expected to return a tensor of shape
        ``(batch_size, char_emb, 1)``.
        """
        batch_size = 128
        char_emb = 100
        word_length = 1

        # Use randn rather than the bare torch.Tensor(sizes) constructor: the
        # latter returns uninitialised memory (possibly NaN/inf), which makes
        # the forward passes non-reproducible.
        x = torch.randn(batch_size, char_emb, word_length)
        # Swap the last two axes -> (batch_size, word_length, char_emb).
        x = x.transpose(1, 2)
        expected_input_shape = (batch_size, word_length, char_emb)
        expected_output_shape = (batch_size, char_emb, 1)

        # CNN-based character embedding.
        cce = ConvCharEmbedding(char_emb)
        y = cce(x)
        # Checked after the call, so this also confirms the input shape is
        # unchanged by the forward pass.
        self.assertEqual(tuple(x.shape), expected_input_shape)
        print("CNN Char Emb input: ", x.shape)
        self.assertEqual(tuple(y.shape), expected_output_shape)
        print("CNN Char Emb output: ", y.shape)  # [128, 100, 1]

        # LSTM-based character embedding, fed the same input.
        lce = LSTMCharEmbedding(char_emb)
        o = lce(x)
        self.assertEqual(tuple(x.shape), expected_input_shape)
        print("LSTM Char Emb input: ", x.shape)
        self.assertEqual(tuple(o.shape), expected_output_shape)
        print("LSTM Char Emb size: ", o.shape)  # [128, 100, 1]