diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index b0f71f52..8800e1d6 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -3,9 +3,7 @@ __all__ = [
     'Callback',
     'Event',
     'Filter',
-    'CallbackManager',
     'CheckpointCallback',
-    'choose_progress_callback',
     'ProgressCallback',
     'RichCallback',
     "LRSchedCallback",
@@ -54,7 +52,6 @@ __all__ = [
     'DataSet',
     'FieldArray',
     'Instance',
-    'ApplyResultException',
 
     # drivers
     "TorchSingleDriver",
diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py
index 82b1a756..27770115 100644
--- a/fastNLP/core/callbacks/callback_manager.py
+++ b/fastNLP/core/callbacks/callback_manager.py
@@ -180,8 +180,8 @@ class CallbackManager:
                     states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer)
 
         if len(_duplicated_callbacks) > 0:
-            logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, "
-                           f"and we will only save the first callback's state we meet.")
+            logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, "
+                           f"fastNLP will only save the first callback's state.")
 
         # 2. 每一个具体的 callback 函数的 filter 的状态;
         _record_duplicated_callback_names = set()
@@ -223,8 +223,8 @@ class CallbackManager:
                 _duplicated_callback_names.add(each_callback_filters[0])
 
         if len(_duplicated_callback_names) > 0:
-            logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, "
-                           f"and we will only load the first callback's state we meet.")
+            logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, "
+                                     f"fastNLP will only load the first callback's state.")
 
         # 2. 再恢复每一个 callback 的单独的状态;
         # 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样,
@@ -235,8 +235,6 @@ class CallbackManager:
                 _already_loaded_callback_names.add(each_callback.callback_name)
                 # 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态;
                 each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"])
-            else:
-                each_callback.on_load_checkpoint(trainer, None)
 
     @property
     def has_trainer_checkpoint(self) -> bool:
diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py
index 60dcc862..3105acba 100644
--- a/tests/core/callbacks/test_checkpoint_callback_torch.py
+++ b/tests/core/callbacks/test_checkpoint_callback_torch.py
@@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+from tests.helpers.utils import Capturing
 from torchmetrics import Accuracy
 
 from fastNLP.core.log import logger
@@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1(
             dist.destroy_process_group()
 
 
+@pytest.mark.torch
+def test_load_state(model_and_optimizers):
+    try:
+        path = Path.cwd().joinpath(f"test_model_checkpoint")
+        path.mkdir(exist_ok=True, parents=True)
+        from fastNLP import Event, Callback
+        @Trainer.on(Event.on_before_backward(every=3), marker='all')
+        def print_outputs(*args):
+            print("????")
+
+        class StateCallback(Callback):
+            def __init__(self, name):
+                self.name = name
+
+            def on_save_checkpoint(self, trainer):
+                return {'name': self.name}
+
+            def on_load_checkpoint(self, trainer, states):
+                self.name = states['name']
+
+            def on_train_end(self, trainer):
+                print(self.name)
+
+        callbacks = [StateCallback('old_callback1'), StateCallback('old_callback2'),
+                     CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer')]
+
+        trainer = Trainer(
+            model=model_and_optimizers.model,
+            driver='torch',
+            device='cpu',
+            optimizers=model_and_optimizers.optimizers,
+            train_dataloader=model_and_optimizers.train_dataloader,
+            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+            input_mapping=model_and_optimizers.input_mapping,
+            output_mapping=model_and_optimizers.output_mapping,
+            metrics=model_and_optimizers.metrics,
+            n_epochs=3,
+            callbacks=callbacks,
+            output_from_new_proc="all"
+        )
+        trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)
+
+        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
+        epoch_2_path = all_saved_model_paths['trainer-epoch_2']
+
+        callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')]
+        trainer = Trainer(
+            model=model_and_optimizers.model,
+            driver='torch',
+            device='cpu',
+            optimizers=model_and_optimizers.optimizers,
+            train_dataloader=model_and_optimizers.train_dataloader,
+            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+            input_mapping=model_and_optimizers.input_mapping,
+            output_mapping=model_and_optimizers.output_mapping,
+            metrics=model_and_optimizers.metrics,
+            n_epochs=3,
+            callbacks=callbacks,
+            output_from_new_proc="all"
+        )
+        trainer.load(folder=epoch_2_path)
+        with Capturing() as output:
+            trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)
+
+        assert 'old_callback1' in output[0]
+        assert 'new_callback2' in output[0]
+        assert output[0].count('???') == 1
+
+    finally:
+        rank_zero_rm(path)
+
+
 @pytest.mark.torch
 # 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载;
 @pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)