Browse Source

Update bert.py

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
65b141c117
1 changed files with 12 additions and 0 deletions
  1. +12
    -0
      fastNLP/modules/encoder/bert.py

+ 12
- 0
fastNLP/modules/encoder/bert.py View File

@@ -538,6 +538,18 @@ class BertModel(nn.Module):
raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')

model_type = 'BERT'
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'bert' not in key:
new_key = 'bert.' + key
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

old_keys = []
new_keys = []
for key in state_dict.keys():


Loading…
Cancel
Save