|
|
@@ -538,6 +538,18 @@ class BertModel(nn.Module): |
|
|
|
raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') |
|
|
|
|
|
|
|
model_type = 'BERT' |
|
|
|
old_keys = [] |
|
|
|
new_keys = [] |
|
|
|
for key in state_dict.keys(): |
|
|
|
new_key = None |
|
|
|
if 'bert' not in key: |
|
|
|
new_key = 'bert.' + key |
|
|
|
if new_key: |
|
|
|
old_keys.append(key) |
|
|
|
new_keys.append(new_key) |
|
|
|
for old_key, new_key in zip(old_keys, new_keys): |
|
|
|
state_dict[new_key] = state_dict.pop(old_key) |
|
|
|
|
|
|
|
old_keys = [] |
|
|
|
new_keys = [] |
|
|
|
for key in state_dict.keys(): |
|
|
|