Browse Source

修复filter state bug

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
2b9e09e07a
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      fastNLP/core/callbacks/callback_manager.py

+ 5
- 1
fastNLP/core/callbacks/callback_manager.py View File

@@ -188,6 +188,8 @@ class CallbackManager:
for each_callback_filters in self._callback_filters:
if each_callback_filters[0] not in _record_duplicated_callback_names:
_record_duplicated_callback_names.add(each_callback_filters[0])
if 'filter_states' not in states[each_callback_filters[0]]:
states[each_callback_filters[0]]["filter_states"] = {}
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict()

# 3. 保存 callback_counter;
@@ -214,7 +216,9 @@ class CallbackManager:
if each_callback_filters[0] in states:
if each_callback_filters[0] not in _already_loaded_callback_names:
_already_loaded_callback_names.add(each_callback_filters[0])
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]])
if 'filter_states' in states[each_callback_filters[0]] and \
each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']:
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]])
else:
_duplicated_callback_names.add(each_callback_filters[0])



Loading…
Cancel
Save