diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 821b9c5c..6ff03fa2 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -538,6 +538,18 @@ class BertModel(nn.Module): raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') model_type = 'BERT' + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'bert' not in key: + new_key = 'bert.' + key + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + old_keys = [] new_keys = [] for key in state_dict.keys():