@@ -69,6 +69,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
else: | else: | ||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | raise ValueError(f"Cannot recognize {model_dir_or_name}.") | ||||
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | ||||
num_layers = self.model.encoder.num_layers | |||||
if layers == 'mix': | if layers == 'mix': | ||||
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), | self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), | ||||
@@ -78,9 +79,9 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | ||||
else: | else: | ||||
layers = list(map(int, layers.split(','))) | layers = list(map(int, layers.split(','))) | ||||
assert len(layers) > 0, "Must choose one output" | |||||
assert len(layers) > 0, "Must choose at least one output, but got None." | |||||
for layer in layers: | for layer in layers: | ||||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||||
assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}." | |||||
self.layers = layers | self.layers = layers | ||||
self._get_outputs = self._get_layer_outputs | self._get_outputs = self._get_layer_outputs | ||||
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | ||||
@@ -241,7 +241,7 @@ class BertForQuestionAnswering(BaseModel): | |||||
def forward(self, words): | def forward(self, words): | ||||
""" | """ | ||||
:param torch.LongTensor words: [batch_size, seq_len] | :param torch.LongTensor words: [batch_size, seq_len] | ||||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len] | |||||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | |||||
""" | """ | ||||
sequence_output = self.bert(words) | sequence_output = self.bert(words) | ||||
logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] | logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] | ||||
@@ -0,0 +1,229 @@ | |||||
! 33 | |||||
" 34 | |||||
# 35 | |||||
$ 36 | |||||
% 37 | |||||
& 38 | |||||
' 39 | |||||
( 40 | |||||
) 41 | |||||
* 42 | |||||
+ 43 | |||||
, 44 | |||||
- 45 | |||||
. 46 | |||||
/ 47 | |||||
0 48 | |||||
1 49 | |||||
2 50 | |||||
3 51 | |||||
4 52 | |||||
5 53 | |||||
6 54 | |||||
7 55 | |||||
8 56 | |||||
9 57 | |||||
: 58 | |||||
; 59 | |||||
< 60 | |||||
= 61 | |||||
> 62 | |||||
? 63 | |||||
@ 64 | |||||
A 65 | |||||
B 66 | |||||
C 67 | |||||
D 68 | |||||
E 69 | |||||
F 70 | |||||
G 71 | |||||
H 72 | |||||
I 73 | |||||
J 74 | |||||
K 75 | |||||
L 76 | |||||
M 77 | |||||
N 78 | |||||
O 79 | |||||
P 80 | |||||
Q 81 | |||||
R 82 | |||||
S 83 | |||||
T 84 | |||||
U 85 | |||||
V 86 | |||||
W 87 | |||||
X 88 | |||||
Y 89 | |||||
Z 90 | |||||
[ 91 | |||||
\ 92 | |||||
] 93 | |||||
^ 94 | |||||
_ 95 | |||||
` 96 | |||||
a 97 | |||||
b 98 | |||||
c 99 | |||||
d 100 | |||||
e 101 | |||||
f 102 | |||||
g 103 | |||||
h 104 | |||||
i 105 | |||||
j 106 | |||||
k 107 | |||||
l 108 | |||||
m 109 | |||||
n 110 | |||||
o 111 | |||||
p 112 | |||||
q 113 | |||||
r 114 | |||||
s 115 | |||||
t 116 | |||||
u 117 | |||||
v 118 | |||||
w 119 | |||||
x 120 | |||||
y 121 | |||||
z 122 | |||||
{ 123 | |||||
| 124 | |||||
} 125 | |||||
~ 126 | |||||
127 | |||||
128 | |||||
129 | |||||
130 | |||||
131 | |||||
132 | |||||
134 | |||||
135 | |||||
136 | |||||
137 | |||||
138 | |||||
139 | |||||
140 | |||||
141 | |||||
142 | |||||
143 | |||||
144 | |||||
145 | |||||
146 | |||||
147 | |||||
148 | |||||
149 | |||||
150 | |||||
151 | |||||
152 | |||||
153 | |||||
154 | |||||
155 | |||||
156 | |||||
157 | |||||
158 | |||||
159 | |||||
160 | |||||
¡ 161 | |||||
¢ 162 | |||||
£ 163 | |||||
¤ 164 | |||||
¥ 165 | |||||
¦ 166 | |||||
§ 167 | |||||
¨ 168 | |||||
© 169 | |||||
ª 170 | |||||
« 171 | |||||
¬ 172 | |||||
173 | |||||
® 174 | |||||
¯ 175 | |||||
° 176 | |||||
± 177 | |||||
² 178 | |||||
³ 179 | |||||
´ 180 | |||||
µ 181 | |||||
¶ 182 | |||||
· 183 | |||||
¸ 184 | |||||
¹ 185 | |||||
º 186 | |||||
» 187 | |||||
¼ 188 | |||||
½ 189 | |||||
¾ 190 | |||||
¿ 191 | |||||
À 192 | |||||
Á 193 | |||||
 194 | |||||
à 195 | |||||
Ä 196 | |||||
Å 197 | |||||
Æ 198 | |||||
Ç 199 | |||||
È 200 | |||||
É 201 | |||||
Ê 202 | |||||
Ë 203 | |||||
Ì 204 | |||||
Í 205 | |||||
Î 206 | |||||
Ï 207 | |||||
Ð 208 | |||||
Ñ 209 | |||||
Ò 210 | |||||
Ó 211 | |||||
Ô 212 | |||||
Õ 213 | |||||
Ö 214 | |||||
× 215 | |||||
Ø 216 | |||||
Ù 217 | |||||
Ú 218 | |||||
Û 219 | |||||
Ü 220 | |||||
Ý 221 | |||||
Þ 222 | |||||
ß 223 | |||||
à 224 | |||||
á 225 | |||||
â 226 | |||||
ã 227 | |||||
ä 228 | |||||
å 229 | |||||
æ 230 | |||||
ç 231 | |||||
è 232 | |||||
é 233 | |||||
ê 234 | |||||
ë 235 | |||||
ì 236 | |||||
í 237 | |||||
î 238 | |||||
ï 239 | |||||
ð 240 | |||||
ñ 241 | |||||
ò 242 | |||||
ó 243 | |||||
ô 244 | |||||
õ 245 | |||||
ö 246 | |||||
÷ 247 | |||||
ø 248 | |||||
ù 249 | |||||
ú 250 | |||||
û 251 | |||||
ü 252 | |||||
ý 253 | |||||
þ 254 | |||||
ÿ 255 | |||||
<bos> 256 | |||||
<eos> 257 | |||||
<bow> 258 | |||||
<eow> 259 | |||||
<char_pad> 260 | |||||
<oov> 1 | |||||
<pad> -1 |
@@ -0,0 +1,29 @@ | |||||
{ | |||||
"lstm": { | |||||
"use_skip_connections": true, | |||||
"projection_dim": 16, | |||||
"cell_clip": 3, | |||||
"proj_clip": 3, | |||||
"dim": 16, | |||||
"n_layers": 1 | |||||
}, | |||||
"char_cnn": { | |||||
"activation": "relu", | |||||
"filters": [ | |||||
[ | |||||
1, | |||||
16 | |||||
], | |||||
[ | |||||
2, | |||||
16 | |||||
] | |||||
], | |||||
"n_highway": 1, | |||||
"embedding": { | |||||
"dim": 4 | |||||
}, | |||||
"n_characters": 262, | |||||
"max_characters_per_token": 50 | |||||
} | |||||
} |
@@ -29,8 +29,11 @@ class TestDownload(unittest.TestCase): | |||||
class TestBertEmbedding(unittest.TestCase): | class TestBertEmbedding(unittest.TestCase): | ||||
def test_bert_embedding_1(self): | def test_bert_embedding_1(self): | ||||
vocab = Vocabulary().add_word_lst("this is a test .".split()) | |||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert') | |||||
vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split()) | |||||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||||
requires_grad = embed.requires_grad | |||||
embed.requires_grad = not requires_grad | |||||
embed.train() | |||||
words = torch.LongTensor([[2, 3, 4, 0]]) | words = torch.LongTensor([[2, 3, 4, 0]]) | ||||
result = embed(words) | result = embed(words) | ||||
self.assertEqual(result.size(), (1, 4, 16)) | self.assertEqual(result.size(), (1, 4, 16)) |
@@ -18,4 +18,19 @@ class TestDownload(unittest.TestCase): | |||||
# 首先保证所有权重可以加载;上传权重;验证可以下载 | # 首先保证所有权重可以加载;上传权重;验证可以下载 | ||||
class TestRunElmo(unittest.TestCase): | |||||
def test_elmo_embedding(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1') | |||||
words = torch.LongTensor([[0, 1, 2]]) | |||||
hidden = elmo_embed(words) | |||||
print(hidden.size()) | |||||
def test_elmo_embedding_layer_assertion(self): | |||||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||||
try: | |||||
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', | |||||
layers='0,1,2') | |||||
except AssertionError as e: | |||||
print(e) | |||||