@@ -3,9 +3,7 @@ __all__ = [ | |||||
'Callback', | 'Callback', | ||||
'Event', | 'Event', | ||||
'Filter', | 'Filter', | ||||
'CallbackManager', | |||||
'CheckpointCallback', | 'CheckpointCallback', | ||||
'choose_progress_callback', | |||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
"LRSchedCallback", | "LRSchedCallback", | ||||
@@ -54,7 +52,6 @@ __all__ = [ | |||||
'DataSet', | 'DataSet', | ||||
'FieldArray', | 'FieldArray', | ||||
'Instance', | 'Instance', | ||||
'ApplyResultException', | |||||
# drivers | # drivers | ||||
"TorchSingleDriver", | "TorchSingleDriver", | ||||
@@ -180,14 +180,16 @@ class CallbackManager: | |||||
states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer) | states[each_callback.callback_name]["states"] = each_callback.on_save_checkpoint(trainer) | ||||
if len(_duplicated_callbacks) > 0: | if len(_duplicated_callbacks) > 0: | ||||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callbacks}, " | |||||
f"and we will only save the first callback's state we meet.") | |||||
logger.warning(f"Notice these callback_name: {_duplicated_callbacks} are duplicated, " | |||||
f"fastNLP will only save the first callback's state.") | |||||
# 2. 每一个具体的 callback 函数的 filter 的状态; | # 2. 每一个具体的 callback 函数的 filter 的状态; | ||||
_record_duplicated_callback_names = set() | _record_duplicated_callback_names = set() | ||||
for each_callback_filters in self._callback_filters: | for each_callback_filters in self._callback_filters: | ||||
if each_callback_filters[0] not in _record_duplicated_callback_names: | if each_callback_filters[0] not in _record_duplicated_callback_names: | ||||
_record_duplicated_callback_names.add(each_callback_filters[0]) | _record_duplicated_callback_names.add(each_callback_filters[0]) | ||||
if 'filter_states' not in states[each_callback_filters[0]]: | |||||
states[each_callback_filters[0]]["filter_states"] = {} | |||||
states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() | states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]] = each_callback_filters[2].state_dict() | ||||
# 3. 保存 callback_counter; | # 3. 保存 callback_counter; | ||||
@@ -214,13 +216,15 @@ class CallbackManager: | |||||
if each_callback_filters[0] in states: | if each_callback_filters[0] in states: | ||||
if each_callback_filters[0] not in _already_loaded_callback_names: | if each_callback_filters[0] not in _already_loaded_callback_names: | ||||
_already_loaded_callback_names.add(each_callback_filters[0]) | _already_loaded_callback_names.add(each_callback_filters[0]) | ||||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]["filter_states"][each_callback_filters[1]]) | |||||
if 'filter_states' in states[each_callback_filters[0]] and \ | |||||
each_callback_filters[1] in states[each_callback_filters[0]]['filter_states']: | |||||
each_callback_filters[2].load_state_dict(states[each_callback_filters[0]]['filter_states'][each_callback_filters[1]]) | |||||
else: | else: | ||||
_duplicated_callback_names.add(each_callback_filters[0]) | _duplicated_callback_names.add(each_callback_filters[0]) | ||||
if len(_duplicated_callback_names) > 0: | if len(_duplicated_callback_names) > 0: | ||||
logger.warning(f"Notice these callbacks' `callback_name` are duplicated: {_duplicated_callback_names}, " | |||||
f"and we will only load the first callback's state we meet.") | |||||
logger.rank_zero_warning(f"Notice these callback_name: {_duplicated_callback_names} are duplicated, " | |||||
f"fastNLP will only load the first callback's state.") | |||||
# 2. 再恢复每一个 callback 的单独的状态; | # 2. 再恢复每一个 callback 的单独的状态; | ||||
# 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样, | # 每一个我们自己提供的类 callback,都需要重写其特定的 `callback_name` 方法,保证如果两个 callback 的 callback_name 一样, | ||||
@@ -231,8 +235,6 @@ class CallbackManager: | |||||
_already_loaded_callback_names.add(each_callback.callback_name) | _already_loaded_callback_names.add(each_callback.callback_name) | ||||
# 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态; | # 这里要注意,我们已经确保每一个 callback 的 `on_load_checkpoint` 函数拿到的就是其自己的状态; | ||||
each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"]) | each_callback.on_load_checkpoint(trainer, states[each_callback.callback_name]["states"]) | ||||
else: | |||||
each_callback.on_load_checkpoint(trainer, None) | |||||
@property | @property | ||||
def has_trainer_checkpoint(self) -> bool: | def has_trainer_checkpoint(self) -> bool: | ||||
@@ -14,6 +14,7 @@ from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.envs.distributed import rank_zero_rm | from fastNLP.envs.distributed import rank_zero_rm | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
from tests.helpers.utils import Capturing | |||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -428,6 +429,78 @@ def test_trainer_checkpoint_callback_1( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
def test_load_state(model_and_optimizers): | |||||
try: | |||||
path = Path.cwd().joinpath(f"test_model_checkpoint") | |||||
path.mkdir(exist_ok=True, parents=True) | |||||
from fastNLP import Event, Callback | |||||
@Trainer.on(Event.on_before_backward(every=3), marker='all') | |||||
def print_outputs(*args): | |||||
print("????") | |||||
class StateCallback(Callback): | |||||
def __init__(self, name): | |||||
self.name = name | |||||
def on_save_checkpoint(self, trainer): | |||||
return {'name': self.name} | |||||
def on_load_checkpoint(self, trainer, states): | |||||
self.name = states['name'] | |||||
def on_train_end(self, trainer): | |||||
print(self.name) | |||||
callbacks = [StateCallback('old_callback1'), StateCallback('old_callback2'), | |||||
CheckpointCallback(folder=path, every_n_epochs=1, save_object='trainer')] | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver='torch', | |||||
device='cpu', | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=3, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) | |||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | |||||
epoch_2_path = all_saved_model_paths['trainer-epoch_2'] | |||||
callbacks = [StateCallback('new_callback1'), StateCallback('new_callback2')] | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver='torch', | |||||
device='cpu', | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=3, | |||||
callbacks=callbacks, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.load(folder=epoch_2_path) | |||||
with Capturing() as output: | |||||
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) | |||||
assert 'old_callback1' in output[0] | |||||
assert 'new_callback2' in output[0] | |||||
assert output[0].count('???')==1 | |||||
finally: | |||||
rank_zero_rm(path) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | # 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | ||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | ||||