Browse Source

新增Callback on_load_checkpoint测试

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
7763b2e087
3 changed files with 77 additions and 9 deletions
  1. +0
    -3
      fastNLP/core/__init__.py
  2. +4
    -6
      fastNLP/core/callbacks/callback_manager.py
  3. +73
    -0
      tests/core/callbacks/test_checkpoint_callback_torch.py

+ 0
- 3
fastNLP/core/__init__.py View File

@@ -3,9 +3,7 @@ __all__ = [
     'Callback',
     'Event',
     'Filter',
-    'CallbackManager',
     'CheckpointCallback',
-    'choose_progress_callback',
     'ProgressCallback',
     'RichCallback',
     "LRSchedCallback",
@@ -54,7 +52,6 @@ __all__ = [
     'DataSet',
     'FieldArray',
     'Instance',
-    'ApplyResultException',

     # drivers
     "TorchSingleDriver",


+ 4
- 6
fastNLP/core/callbacks/callback_manager.py View File

@@ -180,8 +180,8 @@ class CallbackManager:
                 states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer)

         if len(_duplicated_callbacks) > 0:
-            logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, "
-                           f"and we will only save the first callback's state we meet.")
+            logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, "
+                           f"fastNLP will only save the first callback's state.")

         # 2. 每一个具体的 callback 函数的 filter 的状态;
         _record_duplicated_callback_names = set()
@@ -223,8 +223,8 @@ class CallbackManager:
                 _duplicated_callback_names.add(each_callback_filters[0])

         if len(_duplicated_callback_names) > 0:
-            logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, "
-                           f"and we will only load the first callback's state we meet.")
+            logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, "
+                                     f"fastNLP will only load the first callback's state.")

         # 2. 再恢复每一个 callback 的单独的状态;
         # 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样,
@@ -235,8 +235,6 @@ class CallbackManager:
                 _already_loaded_callback_names.add(each_callback.callback_name)
                 # 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态;
                 each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"])
-            else:
-                each_callback.on_load_checkpoint(trainer, None)

     @property
     def has_trainer_checkpoint(self) -> bool:


+ 73
- 0
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context
 from fastNLP.envs.distributed import rank_zero_rm
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
+from tests.helpers.utils import Capturing
 from torchmetrics import Accuracy
 from fastNLP.core.log import logger

@@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1(
         dist.destroy_process_group()


@pytest.mark.torch
def test_load_state(model_and_optimizers):
    """Callback state must survive a trainer checkpoint save/load round trip.

    Trains with two ``StateCallback`` instances (which share the same default
    ``callback_name``) plus a ``CheckpointCallback``, then builds a fresh
    ``Trainer`` with renamed callbacks, restores the epoch-2 checkpoint, and
    asserts on the stdout captured during the resumed run:

    * ``'old_callback1'`` is printed — the first callback's saved state was
      restored on load;
    * ``'new_callback2'`` is printed — the second callback kept its new state
      (with duplicated ``callback_name``s only the first callback's state is
      saved/loaded);
    * the filtered ``on_before_backward`` hook (``every=3``) fires exactly once
      in the resumed run — presumably because its filter counter is also part
      of the checkpoint (confirm against CallbackManager's filter handling).
    """
    try:
        # Plain string: there is nothing to interpolate here.
        path = Path.cwd().joinpath("test_model_checkpoint")
        path.mkdir(exist_ok=True, parents=True)
        from fastNLP import Event, Callback

        # Class-level hook registered via Trainer.on; marker='all' attaches it
        # to every Trainer created afterwards in this process.
        @Trainer.on(Event.on_before_backward(every=3), marker='all')
        def print_outputs(*args):
            print("????")

        class StateCallback(Callback):
            # Minimal stateful callback: its entire checkpointable state is `name`.
            def __init__(self, name):
                self.name = name

            def on_save_checkpoint(self, trainer):
                return {'name': self.name}

            def on_load_checkpoint(self, trainer, states):
                self.name = states['name']

            def on_train_end(self, trainer):
                # Printed so the test can observe which state survived the reload.
                print(self.name)

        callbacks = [StateCallback('old_callback1'), StateCallback('old_callback2'),
                     CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer')]

        trainer = Trainer(
            model=model_and_optimizers.model,
            driver='torch',
            device='cpu',
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=3,
            callbacks=callbacks,
            output_from_new_proc="all"
        )
        trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)

        # Checkpoints are written under a FASTNLP_LAUNCH_TIME subfolder; pick
        # the one saved at the end of epoch 2.
        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
        epoch_2_path = all_saved_model_paths['trainer-epoch_2']

        # Fresh trainer with renamed callbacks; loading the checkpoint should
        # overwrite the first callback's state with the saved 'old_callback1'.
        callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')]
        trainer = Trainer(
            model=model_and_optimizers.model,
            driver='torch',
            device='cpu',
            optimizers=model_and_optimizers.optimizers,
            train_dataloader=model_and_optimizers.train_dataloader,
            evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
            input_mapping=model_and_optimizers.input_mapping,
            output_mapping=model_and_optimizers.output_mapping,
            metrics=model_and_optimizers.metrics,
            n_epochs=3,
            callbacks=callbacks,
            output_from_new_proc="all"
        )
        trainer.load(folder=epoch_2_path)
        with Capturing() as output:
            trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2)

        assert 'old_callback1' in output[0]
        assert 'new_callback2' in output[0]
        # str.count is non-overlapping, so "????" contains '???' exactly once;
        # this therefore checks the hook printed exactly one time after resuming.
        assert output[0].count('???') == 1

    finally:
        # Clean up the checkpoint directory even when an assertion fails.
        rank_zero_rm(path)


@pytest.mark.torch
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载;
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)


Loading…
Cancel
Save