@@ -69,6 +69,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | |||
num_layers = self.model.encoder.num_layers | |||
if layers == 'mix': | |||
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), | |||
@@ -78,9 +79,9 @@ class ElmoEmbedding(ContextualEmbedding): | |||
self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | |||
else: | |||
layers = list(map(int, layers.split(','))) | |||
assert len(layers) > 0, "Must choose one output" | |||
assert len(layers) > 0, "Must choose at least one output, but got None." | |||
for layer in layers: | |||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||
assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}." | |||
self.layers = layers | |||
self._get_outputs = self._get_layer_outputs | |||
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | |||
@@ -241,7 +241,7 @@ class BertForQuestionAnswering(BaseModel): | |||
def forward(self, words): | |||
""" | |||
:param torch.LongTensor words: [batch_size, seq_len] | |||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len] | |||
:return: 一个包含num_labels个logit的dict,每一个logit的形状都是[batch_size, seq_len + 2] | |||
""" | |||
sequence_output = self.bert(words) | |||
logits = self.qa_outputs(sequence_output) # [batch_size, seq_len, num_labels] | |||
@@ -0,0 +1,229 @@ | |||
! 33 | |||
" 34 | |||
# 35 | |||
$ 36 | |||
% 37 | |||
& 38 | |||
' 39 | |||
( 40 | |||
) 41 | |||
* 42 | |||
+ 43 | |||
, 44 | |||
- 45 | |||
. 46 | |||
/ 47 | |||
0 48 | |||
1 49 | |||
2 50 | |||
3 51 | |||
4 52 | |||
5 53 | |||
6 54 | |||
7 55 | |||
8 56 | |||
9 57 | |||
: 58 | |||
; 59 | |||
< 60 | |||
= 61 | |||
> 62 | |||
? 63 | |||
@ 64 | |||
A 65 | |||
B 66 | |||
C 67 | |||
D 68 | |||
E 69 | |||
F 70 | |||
G 71 | |||
H 72 | |||
I 73 | |||
J 74 | |||
K 75 | |||
L 76 | |||
M 77 | |||
N 78 | |||
O 79 | |||
P 80 | |||
Q 81 | |||
R 82 | |||
S 83 | |||
T 84 | |||
U 85 | |||
V 86 | |||
W 87 | |||
X 88 | |||
Y 89 | |||
Z 90 | |||
[ 91 | |||
\ 92 | |||
] 93 | |||
^ 94 | |||
_ 95 | |||
` 96 | |||
a 97 | |||
b 98 | |||
c 99 | |||
d 100 | |||
e 101 | |||
f 102 | |||
g 103 | |||
h 104 | |||
i 105 | |||
j 106 | |||
k 107 | |||
l 108 | |||
m 109 | |||
n 110 | |||
o 111 | |||
p 112 | |||
q 113 | |||
r 114 | |||
s 115 | |||
t 116 | |||
u 117 | |||
v 118 | |||
w 119 | |||
x 120 | |||
y 121 | |||
z 122 | |||
{ 123 | |||
| 124 | |||
} 125 | |||
~ 126 | |||
127 | |||
128 | |||
129 | |||
130 | |||
131 | |||
132 | |||
134 | |||
135 | |||
136 | |||
137 | |||
138 | |||
139 | |||
140 | |||
141 | |||
142 | |||
143 | |||
144 | |||
145 | |||
146 | |||
147 | |||
148 | |||
149 | |||
150 | |||
151 | |||
152 | |||
153 | |||
154 | |||
155 | |||
156 | |||
157 | |||
158 | |||
159 | |||
160 | |||
¡ 161 | |||
¢ 162 | |||
£ 163 | |||
¤ 164 | |||
¥ 165 | |||
¦ 166 | |||
§ 167 | |||
¨ 168 | |||
© 169 | |||
ª 170 | |||
« 171 | |||
¬ 172 | |||
173 | |||
® 174 | |||
¯ 175 | |||
° 176 | |||
± 177 | |||
² 178 | |||
³ 179 | |||
´ 180 | |||
µ 181 | |||
¶ 182 | |||
· 183 | |||
¸ 184 | |||
¹ 185 | |||
º 186 | |||
» 187 | |||
¼ 188 | |||
½ 189 | |||
¾ 190 | |||
¿ 191 | |||
À 192 | |||
Á 193 | |||
 194 | |||
à 195 | |||
Ä 196 | |||
Å 197 | |||
Æ 198 | |||
Ç 199 | |||
È 200 | |||
É 201 | |||
Ê 202 | |||
Ë 203 | |||
Ì 204 | |||
Í 205 | |||
Î 206 | |||
Ï 207 | |||
Ð 208 | |||
Ñ 209 | |||
Ò 210 | |||
Ó 211 | |||
Ô 212 | |||
Õ 213 | |||
Ö 214 | |||
× 215 | |||
Ø 216 | |||
Ù 217 | |||
Ú 218 | |||
Û 219 | |||
Ü 220 | |||
Ý 221 | |||
Þ 222 | |||
ß 223 | |||
à 224 | |||
á 225 | |||
â 226 | |||
ã 227 | |||
ä 228 | |||
å 229 | |||
æ 230 | |||
ç 231 | |||
è 232 | |||
é 233 | |||
ê 234 | |||
ë 235 | |||
ì 236 | |||
í 237 | |||
î 238 | |||
ï 239 | |||
ð 240 | |||
ñ 241 | |||
ò 242 | |||
ó 243 | |||
ô 244 | |||
õ 245 | |||
ö 246 | |||
÷ 247 | |||
ø 248 | |||
ù 249 | |||
ú 250 | |||
û 251 | |||
ü 252 | |||
ý 253 | |||
þ 254 | |||
ÿ 255 | |||
<bos> 256 | |||
<eos> 257 | |||
<bow> 258 | |||
<eow> 259 | |||
<char_pad> 260 | |||
<oov> 1 | |||
<pad> -1 |
@@ -0,0 +1,29 @@ | |||
{ | |||
"lstm": { | |||
"use_skip_connections": true, | |||
"projection_dim": 16, | |||
"cell_clip": 3, | |||
"proj_clip": 3, | |||
"dim": 16, | |||
"n_layers": 1 | |||
}, | |||
"char_cnn": { | |||
"activation": "relu", | |||
"filters": [ | |||
[ | |||
1, | |||
16 | |||
], | |||
[ | |||
2, | |||
16 | |||
] | |||
], | |||
"n_highway": 1, | |||
"embedding": { | |||
"dim": 4 | |||
}, | |||
"n_characters": 262, | |||
"max_characters_per_token": 50 | |||
} | |||
} |
@@ -29,8 +29,11 @@ class TestDownload(unittest.TestCase): | |||
class TestBertEmbedding(unittest.TestCase): | |||
def test_bert_embedding_1(self): | |||
vocab = Vocabulary().add_word_lst("this is a test .".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert') | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
requires_grad = embed.requires_grad | |||
embed.requires_grad = not requires_grad | |||
embed.train() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) |
@@ -18,4 +18,19 @@ class TestDownload(unittest.TestCase): | |||
# 首先保证所有权重可以加载;上传权重;验证可以下载 | |||
class TestRunElmo(unittest.TestCase): | |||
def test_elmo_embedding(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', layers='0,1') | |||
words = torch.LongTensor([[0, 1, 2]]) | |||
hidden = elmo_embed(words) | |||
print(hidden.size()) | |||
def test_elmo_embedding_layer_assertion(self): | |||
vocab = Vocabulary().add_word_lst("This is a test .".split()) | |||
try: | |||
elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_elmo', | |||
layers='0,1,2') | |||
except AssertionError as e: | |||
print(e) | |||