| @@ -538,6 +538,18 @@ class BertModel(nn.Module): | |||||
| raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | ||||
| model_type = 'BERT' | model_type = 'BERT' | ||||
| old_keys = [] | |||||
| new_keys = [] | |||||
| for key in state_dict.keys(): | |||||
| new_key = None | |||||
| if 'bert' not in key: | |||||
| new_key = 'bert.' + key | |||||
| if new_key: | |||||
| old_keys.append(key) | |||||
| new_keys.append(new_key) | |||||
| for old_key, new_key in zip(old_keys, new_keys): | |||||
| state_dict[new_key] = state_dict.pop(old_key) | |||||
| old_keys = [] | old_keys = [] | ||||
| new_keys = [] | new_keys = [] | ||||
| for key in state_dict.keys(): | for key in state_dict.keys(): | ||||