diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py index 0ec0caa0..d19a3577 100644 --- a/fastNLP/embeddings/elmo_embedding.py +++ b/fastNLP/embeddings/elmo_embedding.py @@ -69,6 +69,7 @@ class ElmoEmbedding(ContextualEmbedding): else: raise ValueError(f"Cannot recognize {model_dir_or_name}.") self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) + num_layers = self.model.encoder.num_layers if layers == 'mix': self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), @@ -78,9 +79,9 @@ class ElmoEmbedding(ContextualEmbedding): self._embed_size = self.model.config['lstm']['projection_dim'] * 2 else: layers = list(map(int, layers.split(','))) - assert len(layers) > 0, "Must choose one output" + assert len(layers) > 0, "Must choose at least one output, but got None." for layer in layers: - assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." + assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}." self.layers = layers self._get_outputs = self._get_layer_outputs self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index 85c3af8c..30ed0cd8 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -241,7 +241,7 @@ class BertForQuestionAnswering(BaseModel): def forward(self, words): """ :param torch.LongTensor words: [batch_size, seq_len] - :return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len] + :return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] """ sequence_output = self.bert(words) logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] diff --git a/test/data_for_tests/embedding/small_elmo/char.dic b/test/data_for_tests/embedding/small_elmo/char.dic new file mode 100644 index 00000000..74285f34 --- /dev/null +++ b/test/data_for_tests/embedding/small_elmo/char.dic @@ -0,0 +1,229 @@ +! 33 +" 34 +# 35 +$ 36 +% 37 +& 38 +' 39 +( 40 +) 41 +* 42 ++ 43 +, 44 +- 45 +. 46 +/ 47 +0 48 +1 49 +2 50 +3 51 +4 52 +5 53 +6 54 +7 55 +8 56 +9 57 +: 58 +; 59 +< 60 += 61 +> 62 +? 63 +@ 64 +A 65 +B 66 +C 67 +D 68 +E 69 +F 70 +G 71 +H 72 +I 73 +J 74 +K 75 +L 76 +M 77 +N 78 +O 79 +P 80 +Q 81 +R 82 +S 83 +T 84 +U 85 +V 86 +W 87 +X 88 +Y 89 +Z 90 +[ 91 +\ 92 +] 93 +^ 94 +_ 95 +` 96 +a 97 +b 98 +c 99 +d 100 +e 101 +f 102 +g 103 +h 104 +i 105 +j 106 +k 107 +l 108 +m 109 +n 110 +o 111 +p 112 +q 113 +r 114 +s 115 +t 116 +u 117 +v 118 +w 119 +x 120 +y 121 +z 122 +{ 123 +| 124 +} 125 +~ 126 + 127 +€ 128 + 129 +‚ 130 +ƒ 131 +„ 132 +† 134 +‡ 135 +ˆ 136 +‰ 137 +Š 138 +‹ 139 +Œ 140 + 141 +Ž 142 + 143 + 144 +‘ 145 +’ 146 +“ 147 +” 148 +• 149 +– 150 +— 151 +˜ 152 +™ 153 +š 154 +› 155 +œ 156 + 157 +ž 158 +Ÿ 159 +  160 +¡ 161 +¢ 162 +£ 163 +¤ 164 +¥ 165 +¦ 166 +§ 167 +¨ 168 +© 169 +ª 170 +« 171 +¬ 172 +­ 173 +® 174 +¯ 175 +° 176 +± 177 +² 178 +³ 179 +´ 180 +µ 181 +¶ 182 +· 183 +¸ 184 +¹ 185 +º 186 +» 187 +¼ 188 +½ 189 +¾ 190 +¿ 191 +À 192 +Á 193 + 194 +à 195 +Ä 196 +Å 197 +Æ 198 +Ç 199 +È 200 +É 201 +Ê 202 +Ë 203 +Ì 204 +Í 205 +Î 206 +Ï 207 +Ð 208 +Ñ 209 +Ò 210 +Ó 211 +Ô 212 +Õ 213 +Ö 214 +× 215 +Ø 216 +Ù 217 +Ú 218 +Û 219 +Ü 220 +Ý 221 +Þ 222 +ß 223 +à 224 +á 225 +â 226 +ã 227 +ä 228 +å 229 +æ 230 +ç 231 +è 232 +é 233 +ê 234 +ë 235 +ì 236 +í 237 +î 238 +ï 239 +ð 240 +ñ 241 +ò 242 +ó 243 +ô 244 +õ 245 +ö 246 +÷ 247 +ø 248 +ù 249 +ú 250 +û 251 +ü 252 +ý 253 +þ 254 +ÿ 255 + 256 + 257 + 258 + 259 + 260 + 1 + -1 diff --git a/test/data_for_tests/embedding/small_elmo/elmo_1x16_16_32cnn_1xhighway_options.json b/test/data_for_tests/embedding/small_elmo/elmo_1x16_16_32cnn_1xhighway_options.json new file mode 100644 index 00000000..9c02ef72 --- /dev/null +++ b/test/data_for_tests/embedding/small_elmo/elmo_1x16_16_32cnn_1xhighway_options.json @@ -0,0 +1,29 @@ +{ + "lstm": { + "use_skip_connections": true, + "projection_dim": 16, + "cell_clip": 3, + "proj_clip": 3, + "dim": 16, + "n_layers": 1 + }, + "char_cnn": { + "activation": "relu", + "filters": [ + [ + 1, + 16 + ], + [ + 2, + 16 + ] + ], + "n_highway": 1, + "embedding": { + "dim": 4 + }, + "n_characters": 262, + "max_characters_per_token": 50 + } +} diff --git a/test/data_for_tests/embedding/small_elmo/elmo_mini_for_testing.pkl b/test/data_for_tests/embedding/small_elmo/elmo_mini_for_testing.pkl new file mode 100644 index 00000000..4c72f3d5 Binary files /dev/null and b/test/data_for_tests/embedding/small_elmo/elmo_mini_for_testing.pkl differ diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py index 6a4a0ffa..71511458 100644 --- a/test/embeddings/test_bert_embedding.py +++ b/test/embeddings/test_bert_embedding.py @@ -29,8 +29,11 @@ class TestDownload(unittest.TestCase): class TestBertEmbedding(unittest.TestCase): def test_bert_embedding_1(self): - vocab = Vocabulary().add_word_lst("this is a test .".split()) - embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert') + vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split()) + embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) + requires_grad = embed.requires_grad + embed.requires_grad = not requires_grad + embed.train() words = torch.LongTensor([[2, 3, 4, 0]]) result = embed(words) self.assertEqual(result.size(), (1, 4, 16)) diff --git a/test/embeddings/test_elmo_embedding.py b/test/embeddings/test_elmo_embedding.py index a087f0a4..bfb31659 100644 --- a/test/embeddings/test_elmo_embedding.py +++ b/test/embeddings/test_elmo_embedding.py @@ -18,4 +18,19 @@ class TestDownload(unittest.TestCase): # 首先保证所有权重可以加载;上传权重;验证可以下载 +class TestRunElmo(unittest.TestCase): + def test_elmo_embedding(self): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1') + words = torch.LongTensor([[0, 1, 2]]) + hidden = elmo_embed(words) + print(hidden.size()) + + def test_elmo_embedding_layer_assertion(self): + vocab = Vocabulary().add_word_lst("This is a test .".split()) + try: + elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', + layers='0,1,2') + except AssertionError as e: + print(e)