diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index f34c5dd3..82b1a756 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -188,6 +188,8 @@ class CallbackManager: for each_callback_filters in self._callback_filters: if each_callback_filters[0] not in _record_duplicated_callback_names: _record_duplicated_callback_names.add(each_callback_filters[0]) + if 'filter_states' not in states[each_callback_filters[0]]: + states[each_callback_filters[0]]["filter_states"] = {} states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() # 3. 保存 callback_counter; @@ -214,7 +216,9 @@ class CallbackManager: if each_callback_filters[0] in states: if each_callback_filters[0] not in _already_loaded_callback_names: _already_loaded_callback_names.add(each_callback_filters[0]) - each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) + if 'filter_states' in states[each_callback_filters[0]] and \ + each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']: + each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]]) else: _duplicated_callback_names.add(each_callback_filters[0])