| @@ -188,6 +188,8 @@ class CallbackManager: | |||||
| for each_callback_filters in self._callback_filters: | for each_callback_filters in self._callback_filters: | ||||
| if each_callback_filters[0] not in _record_duplicated_callback_names: | if each_callback_filters[0] not in _record_duplicated_callback_names: | ||||
| _record_duplicated_callback_names.add(each_callback_filters[0]) | _record_duplicated_callback_names.add(each_callback_filters[0]) | ||||
| if 'filter_states' not in states[each_callback_filters[0]]: | |||||
| states[each_callback_filters[0]]["filter_states"] = {} | |||||
| states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() | states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() | ||||
| # 3. 保存 callback_counter; | # 3. 保存 callback_counter; | ||||
| @@ -214,7 +216,9 @@ class CallbackManager: | |||||
| if each_callback_filters[0] in states: | if each_callback_filters[0] in states: | ||||
| if each_callback_filters[0] not in _already_loaded_callback_names: | if each_callback_filters[0] not in _already_loaded_callback_names: | ||||
| _already_loaded_callback_names.add(each_callback_filters[0]) | _already_loaded_callback_names.add(each_callback_filters[0]) | ||||
| each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) | |||||
| if 'filter_states' in states[each_callback_filters[0]] and \ | |||||
| each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']: | |||||
| each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]]) | |||||
| else: | else: | ||||
| _duplicated_callback_names.add(each_callback_filters[0]) | _duplicated_callback_names.add(each_callback_filters[0]) | ||||