From 2b9e09e07af5a40e02f6ce77b192391277c9f73b Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 10 May 2022 02:37:01 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dfilter=20state=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index f34c5dd3..82b1a756 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -188,6 +188,8 @@ class CallbackManager: for each_callback_filters in self._callback_filters: if each_callback_filters[0] not in _record_duplicated_callback_names: _record_duplicated_callback_names.add(each_callback_filters[0]) + if 'filter_states' not in states[each_callback_filters[0]]: + states[each_callback_filters[0]]["filter_states"] = {} states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() # 3. 保存 callback_counter; @@ -214,7 +216,9 @@ class CallbackManager: if each_callback_filters[0] in states: if each_callback_filters[0] not in _already_loaded_callback_names: _already_loaded_callback_names.add(each_callback_filters[0]) - each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) + if 'filter_states' in states[each_callback_filters[0]] and \ + each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']: + each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]]) else: _duplicated_callback_names.add(each_callback_filters[0])