|
|
@@ -188,6 +188,8 @@ class CallbackManager: |
|
|
|
for each_callback_filters in self._callback_filters: |
|
|
|
if each_callback_filters[0] not in _record_duplicated_callback_names: |
|
|
|
_record_duplicated_callback_names.add(each_callback_filters[0]) |
|
|
|
if 'filter_states' not in states[each_callback_filters[0]]: |
|
|
|
states[each_callback_filters[0]]["filter_states"] = {} |
|
|
|
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() |
|
|
|
|
|
|
|
# 3. 保存 callback_counter; |
|
|
@@ -214,7 +216,9 @@ class CallbackManager: |
|
|
|
if each_callback_filters[0] in states: |
|
|
|
if each_callback_filters[0] not in _already_loaded_callback_names: |
|
|
|
_already_loaded_callback_names.add(each_callback_filters[0]) |
|
|
|
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) |
|
|
|
if 'filter_states' in states[each_callback_filters[0]] and \ |
|
|
|
each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']: |
|
|
|
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]]) |
|
|
|
else: |
|
|
|
_duplicated_callback_names.add(each_callback_filters[0]) |
|
|
|
|
|
|
|