@@ -3,9 +3,7 @@ __all__ = [ | |||
'Callback', | |||
'Event', | |||
'Filter', | |||
'CallbackManager', | |||
'CheckpointCallback', | |||
'choose_progress_callback', | |||
'ProgressCallback', | |||
'RichCallback', | |||
"LRSchedCallback", | |||
@@ -54,7 +52,6 @@ __all__ = [ | |||
'DataSet', | |||
'FieldArray', | |||
'Instance', | |||
'ApplyResultException', | |||
# drivers | |||
"TorchSingleDriver", | |||
@@ -180,14 +180,16 @@ class CallbackManager: | |||
states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer) | |||
if len(_duplicated_callbacks) > 0: | |||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, " | |||
f"and we will only save the first callback's state we meet.") | |||
logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, " | |||
f"fastNLP will only save the first callback's state.") | |||
# 2. 每一个具体的 callback 函数的 filter 的状态; | |||
_record_duplicated_callback_names = set() | |||
for each_callback_filters in self._callback_filters: | |||
if each_callback_filters[0] not in _record_duplicated_callback_names: | |||
_record_duplicated_callback_names.add(each_callback_filters[0]) | |||
if 'filter_states' not in states[each_callback_filters[0]]: | |||
states[each_callback_filters[0]]["filter_states"] = {} | |||
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() | |||
# 3. 保存 callback_counter; | |||
@@ -214,13 +216,15 @@ class CallbackManager: | |||
if each_callback_filters[0] in states: | |||
if each_callback_filters[0] not in _already_loaded_callback_names: | |||
_already_loaded_callback_names.add(each_callback_filters[0]) | |||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) | |||
if 'filter_states' in states[each_callback_filters[0]] and \ | |||
each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']: | |||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]]) | |||
else: | |||
_duplicated_callback_names.add(each_callback_filters[0]) | |||
if len(_duplicated_callback_names) > 0: | |||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, " | |||
f"and we will only load the first callback's state we meet.") | |||
logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, " | |||
f"fastNLP will only load the first callback's state.") | |||
# 2. 再恢复每一个 callback 的单独的状态; | |||
# 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样, | |||
@@ -231,8 +235,6 @@ class CallbackManager: | |||
_already_loaded_callback_names.add(each_callback.callback_name) | |||
# 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态; | |||
each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"]) | |||
else: | |||
each_callback.on_load_checkpoint(trainer, None) | |||
@property | |||
def has_trainer_checkpoint(self) -> bool: | |||
@@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.envs.distributed import rank_zero_rm | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||
from tests.helpers.utils import Capturing | |||
from torchmetrics import Accuracy | |||
from fastNLP.core.log import logger | |||
@@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1( | |||
dist.destroy_process_group() | |||
@pytest.mark.torch
def test_load_state(model_and_optimizers):
    """Resume training from a checkpoint and verify callback state restoration.

    Trains once with two ``StateCallback`` instances plus a ``CheckpointCallback``,
    then builds a fresh ``Trainer`` with differently-named ``StateCallback`` s,
    loads the epoch-2 checkpoint, and asserts on the captured output that:

    * ``'old_callback1'`` was restored from the checkpoint while
      ``'new_callback2'`` kept its in-memory state — i.e. when callback names
      collide, only the first callback's saved state is applied
      (presumably both ``StateCallback`` s share one ``callback_name`` —
      NOTE(review): confirm against ``Callback.callback_name``);
    * the ``every=3`` filter counter was restored, so the filtered hook fires
      exactly once in the resumed (shortened) run.
    """
    # Assign `path` BEFORE the try block: the finally clause references it,
    # and an early failure would otherwise raise NameError during cleanup.
    # (Also: plain string literal — the original f-string had no placeholder.)
    path = Path.cwd().joinpath("test_model_checkpoint")
    try:
        path.mkdir(exist_ok=True, parents=True)

        from fastNLP import Event, Callback

        # Registered with marker='all' so it attaches to every Trainer created
        # below; its `every=3` Filter state is part of what the checkpoint saves.
        @Trainer.on(Event.on_before_backward(every=3), marker='all')
        def print_outputs(*args):
            print("????")

        class StateCallback(Callback):
            """Callback whose only state is `name`; prints it on train end."""

            def __init__(self, name):
                self.name = name

            def on_save_checkpoint(self, trainer):
                return {'name': self.name}

            def on_load_checkpoint(self, trainer, states):
                self.name = states['name']

            def on_train_end(self, trainer):
                print(self.name)

        callbacks = [
            StateCallback('old_callback1'),
            StateCallback('old_callback2'),
            CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer'),
        ]
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver='torch',
            device='cpu',
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=3,
            callbacks=callbacks,
            output_from_new_proc="all"
        )
        trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)

        # Checkpoints land under path/<launch-time>/; pick the epoch-2 one.
        all_saved_model_paths = {
            w.name: w
            for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()
        }
        epoch_2_path = all_saved_model_paths['trainer-epoch_2']

        # Fresh trainer with new callback names; loading the checkpoint should
        # overwrite only the first duplicated-name callback's state.
        callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')]
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver='torch',
            device='cpu',
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=3,
            callbacks=callbacks,
            output_from_new_proc="all"
        )
        trainer.load(folder=epoch_2_path)
        with Capturing() as output:
            trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)

        assert 'old_callback1' in output[0]
        assert 'new_callback2' in output[0]
        # The every=3 hook prints "????"; it must appear exactly once after resume.
        assert output[0].count('???') == 1
    finally:
        rank_zero_rm(path)
@pytest.mark.torch | |||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | |||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||