From 65b141c117c2e8f1968baf9d921d931cb79f33d0 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Sat, 14 Mar 2020 21:03:00 +0800 Subject: [PATCH] Update bert.py --- fastNLP/modules/encoder/bert.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 821b9c5c..6ff03fa2 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -538,6 +538,18 @@ class BertModel(nn.Module): raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') model_type = 'BERT' + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'bert' not in key: + new_key = 'bert.' + key + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + old_keys = [] new_keys = [] for key in state_dict.keys():